Example #1
def train_ae(ae_trainer,
             training_distrib,
             num_epochs=100,
             num_batches_per_epoch=500,
             batch_size=512,
             goal_key='image_desired_goal',
             rl_csv_fname='progress.csv'):
    from rlkit.core import logger

    logger.remove_tabular_output(rl_csv_fname, relative_to_snapshot_dir=True)
    logger.add_tabular_output('ae_progress.csv', relative_to_snapshot_dir=True)
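    # From here on, per-epoch AE statistics are written to ae_progress.csv;
    # progress.csv is restored at the end of this function.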

    for epoch in range(num_epochs):
        for batch_num in range(num_batches_per_epoch):
            goals = ptu.from_numpy(
                training_distrib.sample(batch_size)[goal_key])
            batch = dict(raw_next_observations=goals)
            ae_trainer.train_from_torch(batch)
        log = OrderedDict()
        log['epoch'] = epoch
        append_log(log, ae_trainer.eval_statistics, prefix='ae/')
        logger.record_dict(log)
        logger.dump_tabular(with_prefix=True, with_timestamp=False)
        ae_trainer.end_epoch(epoch)

    logger.add_tabular_output(rl_csv_fname, relative_to_snapshot_dir=True)
    logger.remove_tabular_output('ae_progress.csv',
                                 relative_to_snapshot_dir=True)
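# Note: the examples in this listing share one pattern -- temporarily swap
# rlkit's tabular output from progress.csv to a phase-specific CSV, run a
# sub-phase (AE/VAE pretraining, behavior cloning, offline pretraining), then
# swap progress.csv back in. A minimal sketch of that pattern as a reusable
# helper follows; the name `redirect_tabular_output` is hypothetical and only
# the logger calls already used above are assumed.
from contextlib import contextmanager

from rlkit.core import logger


@contextmanager
def redirect_tabular_output(phase_csv, main_csv='progress.csv'):
    # Route tabular logging to the phase-specific CSV.
    logger.remove_tabular_output(main_csv, relative_to_snapshot_dir=True)
    logger.add_tabular_output(phase_csv, relative_to_snapshot_dir=True)
    try:
        yield
    finally:
        # Restore the main CSV even if the phase raises.
        logger.remove_tabular_output(phase_csv, relative_to_snapshot_dir=True)
        logger.add_tabular_output(main_csv, relative_to_snapshot_dir=True)


# Usage sketch: wrap a pretraining loop like the one in Example #1.
# with redirect_tabular_output('ae_progress.csv'):
#     run_ae_pretraining()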
Example #2
def train_vae_and_update_variant(variant):
    from rlkit.core import logger

    skewfit_variant = variant["skewfit_variant"]
    train_vae_variant = variant["train_vae_variant"]
    if skewfit_variant.get("vae_path", None) is None:
        logger.remove_tabular_output("progress.csv",
                                     relative_to_snapshot_dir=True)
        logger.add_tabular_output("vae_progress.csv",
                                  relative_to_snapshot_dir=True)
        vae, vae_train_data, vae_test_data = train_vae(train_vae_variant,
                                                       return_data=True)
        if skewfit_variant.get("save_vae_data", False):
            skewfit_variant["vae_train_data"] = vae_train_data
            skewfit_variant["vae_test_data"] = vae_test_data
        logger.save_extra_data(vae, "vae.pkl", mode="pickle")
        logger.remove_tabular_output("vae_progress.csv",
                                     relative_to_snapshot_dir=True)
        logger.add_tabular_output("progress.csv",
                                  relative_to_snapshot_dir=True)
        skewfit_variant["vae_path"] = vae  # just pass the VAE directly
    else:
        if skewfit_variant.get("save_vae_data", False):
            vae_train_data, vae_test_data, info = generate_vae_dataset(
                train_vae_variant["generate_vae_dataset_kwargs"])
            skewfit_variant["vae_train_data"] = vae_train_data
            skewfit_variant["vae_test_data"] = vae_test_data
Example #3
def train_vae_and_update_variant(variant):
    from rlkit.core import logger
    skewfit_variant = variant['skewfit_variant']
    train_vae_variant = variant['train_vae_variant']
    if skewfit_variant.get('vae_path', None) is None:
        logger.remove_tabular_output('progress.csv',
                                     relative_to_snapshot_dir=True)
        logger.add_tabular_output('vae_progress.csv',
                                  relative_to_snapshot_dir=True)
        vae, vae_train_data, vae_test_data = train_vae(
            train_vae_variant, variant['other_variant'], return_data=True)
        if skewfit_variant.get('save_vae_data', False):
            skewfit_variant['vae_train_data'] = vae_train_data
            skewfit_variant['vae_test_data'] = vae_test_data
        logger.save_extra_data(vae, 'vae.pkl', mode='pickle')
        logger.remove_tabular_output(
            'vae_progress.csv',
            relative_to_snapshot_dir=True,
        )
        logger.add_tabular_output(
            'progress.csv',
            relative_to_snapshot_dir=True,
        )
        skewfit_variant['vae_path'] = vae  # just pass the VAE directly
    else:
        if skewfit_variant.get('save_vae_data', False):
            vae_train_data, vae_test_data, info = generate_vae_dataset(
                train_vae_variant['generate_vae_dataset_kwargs'])
            skewfit_variant['vae_train_data'] = vae_train_data
            skewfit_variant['vae_test_data'] = vae_test_data
Example #4
def train_vae_and_update_config(cfgs):
    from rlkit.core import logger
    if cfgs.VAE_TRAINER.get('vae_path', None) is None:
        logger.remove_tabular_output('progress.csv',
                                     relative_to_snapshot_dir=True)
        logger.add_tabular_output('vae_progress.csv',
                                  relative_to_snapshot_dir=True)
        vae, vae_train_data, vae_test_data = train_vae(cfgs, return_data=True)
        cfgs.VAE_TRAINER.vae_path = vae  # just pass the VAE directly
        if cfgs.VAE_TRAINER.save_vae_data:
            cfgs.VAE_TRAINER.train_data = vae_train_data
            cfgs.VAE_TRAINER.test_data = vae_test_data

        logger.remove_tabular_output(
            'vae_progress.csv',
            relative_to_snapshot_dir=True,
        )
        logger.add_tabular_output(
            'progress.csv',
            relative_to_snapshot_dir=True,
        )
    else:
        vae_train_data, vae_test_data, _ = generate_vae_dataset(cfgs)
        if cfgs.VAE_TRAINER.save_vae_data:
            cfgs.VAE_TRAINER.train_data = vae_train_data
            cfgs.VAE_TRAINER.test_data = vae_test_data
Example #5
    def run(self):
        if self.progress_csv_file_name != 'progress.csv':
            logger.remove_tabular_output('progress.csv',
                                         relative_to_snapshot_dir=True)
            logger.add_tabular_output(
                self.progress_csv_file_name,
                relative_to_snapshot_dir=True,
            )
        timer.return_global_times = True
        for _ in range(self.num_iters):
            self._begin_epoch()
            timer.start_timer('saving')
            logger.save_itr_params(self.epoch, self._get_snapshot())
            timer.stop_timer('saving')
            log_dict, _ = self._train()
            logger.record_dict(log_dict)
            logger.dump_tabular(with_prefix=True, with_timestamp=False)
            self._end_epoch()
        logger.save_itr_params(self.epoch, self._get_snapshot())
        if self.progress_csv_file_name != 'progress.csv':
            logger.remove_tabular_output(
                self.progress_csv_file_name,
                relative_to_snapshot_dir=True,
            )
            logger.add_tabular_output(
                'progress.csv',
                relative_to_snapshot_dir=True,
            )
Example #6
def continue_experiment(args):
    logger.add_text_output('./d_text.txt')
    logger.add_tabular_output('./d_tabular.txt')
    logger.set_snapshot_dir('./snaps')

    extra = joblib.load(args.extra)

    algorithm = extra['algorithm']
    algorithm.farmlist_base = [('0.0.0.0', 15)]
    algorithm.refarm()

    if ptu.gpu_enabled():
        algorithm.cuda()
    algorithm.train()
Example #7
    def pretrain_q_with_bc_data(self, batch_size):
        logger.remove_tabular_output('progress.csv',
                                     relative_to_snapshot_dir=True)
        logger.add_tabular_output('pretrain_q.csv',
                                  relative_to_snapshot_dir=True)

        self.update_policy = True
        # then train policy and Q function together
        prev_time = time.time()
        for i in range(self.num_pretrain_steps):
            self.eval_statistics = dict()
            if i % self.pretraining_logging_period == 0:
                self._need_to_update_eval_statistics = True
            train_data = self.replay_buffer.random_batch(batch_size)
            train_data = np_to_pytorch_batch(train_data)
            obs = train_data['observations']
            next_obs = train_data['next_observations']
            # goals = train_data['resampled_goals']
            train_data['observations'] = obs  # torch.cat((obs, goals), dim=1)
            train_data['next_observations'] = next_obs  # torch.cat((next_obs, goals), dim=1)
            self.train_from_torch(train_data)
            if self.do_pretrain_rollouts and i % self.pretraining_env_logging_period == 0:
                total_ret = self.do_rollouts()
                print("Return at step {} : {}".format(i, total_ret / 20))

            if i % self.pretraining_logging_period == 0:
                if self.do_pretrain_rollouts:
                    self.eval_statistics["pretrain_bc/avg_return"] = total_ret / 20
                self.eval_statistics["batch"] = i
                self.eval_statistics["epoch_time"] = time.time() - prev_time
                logger.record_dict(self.eval_statistics)
                logger.dump_tabular(with_prefix=True, with_timestamp=False)
                prev_time = time.time()

        logger.remove_tabular_output(
            'pretrain_q.csv',
            relative_to_snapshot_dir=True,
        )
        logger.add_tabular_output(
            'progress.csv',
            relative_to_snapshot_dir=True,
        )

        self._need_to_update_eval_statistics = True
        self.eval_statistics = dict()
Example #8
def experiment(variant):
    logger.add_text_output('./d_text.txt')
    logger.add_tabular_output('./d_tabular.txt')
    logger.set_snapshot_dir('./snaps')

    farmer = Farmer([('0.0.0.0', 1)])
    remote_env = farmer.force_acq_env()
    remote_env.set_spaces()
    env = NormalizedBoxEnv(remote_env)

    es = GaussianStrategy(
        action_space=env.action_space,
        max_sigma=0.1,
        min_sigma=0.1,  # Constant sigma
    )
    obs_dim = env.observation_space.low.size
    action_dim = env.action_space.low.size
    qf1 = FlattenMlp(
        input_size=obs_dim + action_dim,
        output_size=1,
        hidden_sizes=[256, 256],
    )
    qf2 = FlattenMlp(
        input_size=obs_dim + action_dim,
        output_size=1,
        hidden_sizes=[256, 256],
    )
    policy = TanhMlpPolicy(
        input_size=obs_dim,
        output_size=action_dim,
        hidden_sizes=[256, 256],
    )
    exploration_policy = PolicyWrappedWithExplorationStrategy(
        exploration_strategy=es,
        policy=policy,
    )
    algorithm = TD3(env,
                    qf1=qf1,
                    qf2=qf2,
                    policy=policy,
                    exploration_policy=exploration_policy,
                    **variant['algo_kwargs'])
    if ptu.gpu_enabled():
        algorithm.cuda()
    algorithm.train()
Example #9
def transfer_encoder_goal_conditioned_sac_experiment(train_variant,
                                                     transfer_variant=None):
    rl_csv_fname = 'progress.csv'
    retrain_rl_csv_fname = 'retrain_progress.csv'
    train_results = encoder_goal_conditioned_sac_experiment(
        **train_variant, rl_csv_fname=rl_csv_fname)
    logger.remove_tabular_output(rl_csv_fname, relative_to_snapshot_dir=True)
    logger.add_tabular_output(retrain_rl_csv_fname,
                              relative_to_snapshot_dir=True)
    if not transfer_variant:
        transfer_variant = {}

    transfer_variant = recursive_dict_update(train_variant, transfer_variant)
    print('Starting transfer exp')
    encoder_goal_conditioned_sac_experiment(**transfer_variant,
                                            is_retraining_from_scratch=True,
                                            train_results=train_results,
                                            rl_csv_fname=retrain_rl_csv_fname)
Example #10
    def pretrain_q_with_bc_data(self, batch_size):
        logger.remove_tabular_output('progress.csv',
                                     relative_to_snapshot_dir=True)
        logger.add_tabular_output('pretrain_q.csv',
                                  relative_to_snapshot_dir=True)

        prev_time = time.time()
        for i in range(self.num_pretrain_steps):
            self.eval_statistics = dict()
            if i % self.pretraining_logging_period == 0:
                self._need_to_update_eval_statistics = True
            train_data = self.replay_buffer.random_batch(self.bc_batch_size)
            train_data = np_to_pytorch_batch(train_data)
            obs = train_data['observations']
            next_obs = train_data['next_observations']
            # goals = train_data['resampled_goals']
            train_data['observations'] = obs  # torch.cat((obs, goals), dim=1)
            train_data['next_observations'] = next_obs  # torch.cat((next_obs, goals), dim=1)
            self.train_from_torch(train_data, pretrain=True)

            if i % self.pretraining_logging_period == 0:
                self.eval_statistics["batch"] = i
                self.eval_statistics["epoch_time"] = time.time() - prev_time
                stats_with_prefix = add_prefix(self.eval_statistics,
                                               prefix="trainer/")
                logger.record_dict(stats_with_prefix)
                logger.dump_tabular(with_prefix=True, with_timestamp=False)
                prev_time = time.time()

        logger.remove_tabular_output(
            'pretrain_q.csv',
            relative_to_snapshot_dir=True,
        )
        logger.add_tabular_output(
            'progress.csv',
            relative_to_snapshot_dir=True,
        )

        self._need_to_update_eval_statistics = True
        self.eval_statistics = dict()
Example #11
def grill_her_full_experiment(variant, mode='td3'):
    train_vae_variant = variant['train_vae_variant']
    grill_variant = variant['grill_variant']
    env_class = variant['env_class']
    env_kwargs = variant['env_kwargs']
    init_camera = variant['init_camera']
    train_vae_variant['generate_vae_dataset_kwargs']['env_class'] = env_class
    train_vae_variant['generate_vae_dataset_kwargs']['env_kwargs'] = env_kwargs
    train_vae_variant['generate_vae_dataset_kwargs'][
        'init_camera'] = init_camera
    grill_variant['env_class'] = env_class
    grill_variant['env_kwargs'] = env_kwargs
    grill_variant['init_camera'] = init_camera
    if 'vae_paths' not in grill_variant:
        logger.remove_tabular_output('progress.csv',
                                     relative_to_snapshot_dir=True)
        logger.add_tabular_output('vae_progress.csv',
                                  relative_to_snapshot_dir=True)
        vae = train_vae(train_vae_variant)
        rdim = train_vae_variant['representation_size']
        vae_file = logger.save_extra_data(vae, 'vae.pkl', mode='pickle')
        logger.remove_tabular_output(
            'vae_progress.csv',
            relative_to_snapshot_dir=True,
        )
        logger.add_tabular_output(
            'progress.csv',
            relative_to_snapshot_dir=True,
        )
        grill_variant['vae_paths'] = {
            str(rdim): vae_file,
        }
        grill_variant['rdim'] = str(rdim)
    if mode == 'td3':
        grill_her_td3_experiment(variant['grill_variant'])
    elif mode == 'twin-sac':
        grill_her_twin_sac_experiment(variant['grill_variant'])
    elif mode == 'sac':
        grill_her_sac_experiment(variant['grill_variant'])
Example #12
def experiment(variant):
    logger.add_text_output('./d_text.txt')
    logger.add_tabular_output('./d_tabular.txt')
    logger.set_snapshot_dir('./snaps')
    farmer = Farmer([('0.0.0.0', 1)])

    remote_env = farmer.force_acq_env()
    remote_env.set_spaces()
    env = NormalizedBoxEnv(remote_env)

    es = OUStrategy(action_space=env.action_space)
    obs_dim = env.observation_space.low.size
    action_dim = env.action_space.low.size

    net_size = variant['net_size']

    qf = FlattenMlp(
        input_size=obs_dim + action_dim,
        output_size=1,
        hidden_sizes=[net_size, net_size],
    )
    policy = TanhMlpPolicy(
        input_size=obs_dim,
        output_size=action_dim,
        hidden_sizes=[net_size, net_size],
    )
    exploration_policy = PolicyWrappedWithExplorationStrategy(
        exploration_strategy=es,
        policy=policy,
    )
    algorithm = DDPG(env,
                     qf=qf,
                     policy=policy,
                     exploration_policy=exploration_policy,
                     **variant['algo_params'])
    if ptu.gpu_enabled():
        algorithm.cuda()
    algorithm.train()
Example #13
def experiment(variant):
    logger.add_text_output('./d_text.txt')
    logger.add_tabular_output('./d_tabular.txt')
    logger.set_snapshot_dir('./snaps')
    farmer = Farmer([('0.0.0.0', 1)])
    remote_env = farmer.force_acq_env()
    remote_env.set_spaces()
    env = NormalizedBoxEnv(remote_env)

    obs_dim = int(np.prod(env.observation_space.shape))
    action_dim = int(np.prod(env.action_space.shape))

    net_size = variant['net_size']
    qf = FlattenMlp(
        hidden_sizes=[net_size, net_size],
        input_size=obs_dim + action_dim,
        output_size=1,
    )
    vf = FlattenMlp(
        hidden_sizes=[net_size, net_size],
        input_size=obs_dim,
        output_size=1,
    )
    policy = TanhGaussianPolicy(
        hidden_sizes=[net_size, net_size],
        obs_dim=obs_dim,
        action_dim=action_dim,
    )
    algorithm = SoftActorCritic(env=env,
                                training_env=env,
                                policy=policy,
                                qf=qf,
                                vf=vf,
                                **variant['algo_params'])
    if ptu.gpu_enabled():
        algorithm.cuda()
    algorithm.train()
Example #14
def train_vae(
        variant, env_kwargs, env_id, env_class, imsize, init_camera, return_data=False
    ):
        from rlkit.misc.ml_util import PiecewiseLinearSchedule, ConstantSchedule
        from rlkit.torch.vae.conv_vae import (
            ConvVAE,
            SpatialAutoEncoder,
            AutoEncoder,
        )
        import rlkit.torch.vae.conv_vae as conv_vae
        from rlkit.torch.vae.vae_trainer import ConvVAETrainer
        from rlkit.core import logger
        import rlkit.torch.pytorch_util as ptu
        from rlkit.pythonplusplus import identity
        import torch

        logger.remove_tabular_output(
            'progress.csv', relative_to_snapshot_dir=True
        )
        logger.add_tabular_output(
            'model_progress.csv', relative_to_snapshot_dir=True
        )

        beta = variant["beta"]
        representation_size = variant.get("representation_size",
            variant.get("latent_sizes", variant.get("embedding_dim", None)))
        use_linear_dynamics = variant.get('use_linear_dynamics', False)
        generate_vae_dataset_fctn = variant.get('generate_vae_data_fctn',
                                                generate_vae_dataset)
        variant['generate_vae_dataset_kwargs']['use_linear_dynamics'] = use_linear_dynamics
        variant['generate_vae_dataset_kwargs']['batch_size'] = variant['algo_kwargs']['batch_size']
        train_dataset, test_dataset, info = generate_vae_dataset_fctn(
            env_kwargs, env_id, env_class, imsize, init_camera,
            **variant['generate_vae_dataset_kwargs']
        )

        if use_linear_dynamics:
            action_dim = train_dataset.data['actions'].shape[2]

        logger.save_extra_data(info)
        logger.get_snapshot_dir()
        if 'beta_schedule_kwargs' in variant:
            beta_schedule = PiecewiseLinearSchedule(
                **variant['beta_schedule_kwargs'])
        else:
            beta_schedule = None
        if 'context_schedule' in variant:
            schedule = variant['context_schedule']
            if type(schedule) is dict:
                context_schedule = PiecewiseLinearSchedule(**schedule)
            else:
                context_schedule = ConstantSchedule(schedule)
            variant['algo_kwargs']['context_schedule'] = context_schedule
        if variant.get('decoder_activation', None) == 'sigmoid':
            decoder_activation = torch.nn.Sigmoid()
        else:
            decoder_activation = identity
        architecture = variant['vae_kwargs'].get('architecture', None)
        if not architecture and imsize == 84:
            architecture = conv_vae.imsize84_default_architecture
        elif not architecture and imsize == 48:
            architecture = conv_vae.imsize48_default_architecture
        variant['vae_kwargs']['architecture'] = architecture
        variant['vae_kwargs']['imsize'] = imsize

        if variant['algo_kwargs'].get('is_auto_encoder', False):
            model = AutoEncoder(
                representation_size,
                decoder_output_activation=decoder_activation,
                **variant['vae_kwargs'])
        elif variant.get('use_spatial_auto_encoder', False):
            model = SpatialAutoEncoder(
                representation_size,
                decoder_output_activation=decoder_activation,
                **variant['vae_kwargs'])
        else:
            vae_class = variant.get('vae_class', ConvVAE)
            if use_linear_dynamics:
                model = vae_class(
                    representation_size,
                    decoder_output_activation=decoder_activation,
                    action_dim=action_dim,
                    **variant['vae_kwargs'])
            else:
                model = vae_class(
                    representation_size,
                    decoder_output_activation=decoder_activation,
                    **variant['vae_kwargs'])
        model.to(ptu.device)

        vae_trainer_class = variant.get('vae_trainer_class', ConvVAETrainer)
        trainer = vae_trainer_class(model, beta=beta,
                           beta_schedule=beta_schedule,
                           **variant['algo_kwargs'])
        save_period = variant['save_period']

        dump_skew_debug_plots = variant.get('dump_skew_debug_plots', False)
        for epoch in range(variant['num_epochs']):
            should_save_imgs = (epoch % save_period == 0)
            trainer.train_epoch(epoch, train_dataset)
            trainer.test_epoch(epoch, test_dataset)

            if should_save_imgs:
                trainer.dump_reconstructions(epoch)
                trainer.dump_samples(epoch)
                if dump_skew_debug_plots:
                    trainer.dump_best_reconstruction(epoch)
                    trainer.dump_worst_reconstruction(epoch)
                    trainer.dump_sampling_histogram(epoch)

            stats = trainer.get_diagnostics()
            for k, v in stats.items():
                logger.record_tabular(k, v)
            logger.dump_tabular()
            trainer.end_epoch(epoch)

            if epoch % 50 == 0:
                logger.save_itr_params(epoch, model)
        logger.save_extra_data(model, 'vae.pkl', mode='pickle')

        logger.remove_tabular_output(
            'model_progress.csv',
            relative_to_snapshot_dir=True,
        )
        logger.add_tabular_output(
            'progress.csv',
            relative_to_snapshot_dir=True,
        )
        if return_data:
            return model, train_dataset, test_dataset
        return model
Example #15
def setup_logger(
        exp_prefix="default",
        variant=None,
        text_log_file="debug.log",
        variant_log_file="variant.json",
        tabular_log_file="progress.csv",
        snapshot_mode="last",
        snapshot_gap=1,
        log_tabular_only=False,
        log_dir=None,
        git_infos=None,
        script_name=None,
        **create_log_dir_kwargs
):
    """
    Set up logger to have some reasonable default settings.

    Will save log output to

        base_log_dir/exp_prefix/exp_name.

    exp_name will be auto-generated to be unique.

    If log_dir is specified, then that directory is used as the output dir.

    :param exp_prefix: The sub-directory for this specific experiment.
    :param variant:
    :param text_log_file:
    :param variant_log_file:
    :param tabular_log_file:
    :param snapshot_mode:
    :param log_tabular_only:
    :param snapshot_gap:
    :param log_dir:
    :param git_infos:
    :param script_name: If set, save the script name to this.
    :return:
    """
    if git_infos is None:
        git_infos = get_git_infos(conf.CODE_DIRS_TO_MOUNT)
    first_time = log_dir is None
    if first_time:
        log_dir = create_log_dir(exp_prefix, **create_log_dir_kwargs)

    if variant is not None:
        logger.log("Variant:")
        logger.log(json.dumps(dict_to_safe_json(variant), indent=2))
        variant_log_path = osp.join(log_dir, variant_log_file)
        logger.log_variant(variant_log_path, variant)

    tabular_log_path = osp.join(log_dir, tabular_log_file)
    text_log_path = osp.join(log_dir, text_log_file)

    logger.add_text_output(text_log_path)
    if first_time:
        logger.add_tabular_output(tabular_log_path)
    else:
        logger._add_output(tabular_log_path, logger._tabular_outputs,
                           logger._tabular_fds, mode='a')
        for tabular_fd in logger._tabular_fds:
            logger._tabular_header_written.add(tabular_fd)
    logger.set_snapshot_dir(log_dir)
    logger.set_snapshot_mode(snapshot_mode)
    logger.set_snapshot_gap(snapshot_gap)
    logger.set_log_tabular_only(log_tabular_only)
    exp_name = log_dir.split("/")[-1]
    logger.push_prefix("[%s] " % exp_name)

    if git_infos is not None:
        for (
            directory, code_diff, code_diff_staged, commit_hash, branch_name
        ) in git_infos:
            if directory[-1] == '/':
                directory = directory[:-1]
            diff_file_name = directory[1:].replace("/", "-") + ".patch"
            diff_staged_file_name = (
                directory[1:].replace("/", "-") + "_staged.patch"
            )
            if code_diff is not None and len(code_diff) > 0:
                with open(osp.join(log_dir, diff_file_name), "w") as f:
                    f.write(code_diff + '\n')
            if code_diff_staged is not None and len(code_diff_staged) > 0:
                with open(osp.join(log_dir, diff_staged_file_name), "w") as f:
                    f.write(code_diff_staged + '\n')
            with open(osp.join(log_dir, "git_infos.txt"), "a") as f:
                f.write("directory: {}\n".format(directory))
                f.write("git hash: {}\n".format(commit_hash))
                f.write("git branch name: {}\n\n".format(branch_name))
    if script_name is not None:
        with open(osp.join(log_dir, "script_name.txt"), "w") as f:
            f.write(script_name)
    return log_dir
Example #16
def setup_logger(
    exp_prefix="default",
    exp_id=0,
    seed=0,
    variant=None,
    base_log_dir=None,
    text_log_file="debug.log",
    variant_log_file="variant.json",
    tabular_log_file="progress.csv",
    snapshot_mode="last",
    snapshot_gap=1,
    log_tabular_only=False,
    log_dir=None,
    git_info=None,
    script_name=None,
):
    """
    Set up logger to have some reasonable default settings.

    Will save log output to

        base_log_dir/exp_prefix/exp_name.

    exp_name will be auto-generated to be unique.

    If log_dir is specified, then that directory is used as the output dir.

    :param exp_prefix: The sub-directory for this specific experiment.
    :param exp_id: The number of the specific experiment run within this
    experiment.
    :param variant:
    :param base_log_dir: The directory where all log should be saved.
    :param text_log_file:
    :param variant_log_file:
    :param tabular_log_file:
    :param snapshot_mode:
    :param log_tabular_only:
    :param snapshot_gap:
    :param log_dir:
    :param git_info:
    :param script_name: If set, save the script name to this.
    :return:
    """
    first_time = log_dir is None
    if first_time:
        log_dir = create_log_dir(exp_prefix,
                                 exp_id=exp_id,
                                 seed=seed,
                                 base_log_dir=base_log_dir)

    if variant is not None:
        logger.log("Variant:")
        logger.log(json.dumps(dict_to_safe_json(variant), indent=2))
        variant_log_path = osp.join(log_dir, variant_log_file)
        logger.log_variant(variant_log_path, variant)

    tabular_log_path = osp.join(log_dir, tabular_log_file)
    text_log_path = osp.join(log_dir, text_log_file)

    logger.add_text_output(text_log_path)
    if first_time:
        logger.add_tabular_output(tabular_log_path)
    else:
        logger._add_output(tabular_log_path,
                           logger._tabular_outputs,
                           logger._tabular_fds,
                           mode='a')
        for tabular_fd in logger._tabular_fds:
            logger._tabular_header_written.add(tabular_fd)
    logger.set_snapshot_dir(log_dir)
    logger.set_snapshot_mode(snapshot_mode)
    logger.set_snapshot_gap(snapshot_gap)
    logger.set_log_tabular_only(log_tabular_only)
    exp_name = log_dir.split("/")[-1]
    logger.push_prefix("[%s] " % exp_name)

    if git_info is not None:
        code_diff, commit_hash, branch_name = git_info
        if code_diff is not None:
            with open(osp.join(log_dir, "code.diff"), "w") as f:
                f.write(code_diff)
        with open(osp.join(log_dir, "git_info.txt"), "w") as f:
            f.write("git hash: {}".format(commit_hash))
            f.write('\n')
            f.write("git branch name: {}".format(branch_name))
    if script_name is not None:
        with open(osp.join(log_dir, "script_name.txt"), "w") as f:
            f.write(script_name)
    return log_dir
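# A minimal usage sketch for the setup_logger variants above (assuming the
# second signature with exp_prefix/exp_id/seed/variant/base_log_dir); the
# experiment name, directory, and variant contents here are illustrative only.
variant = dict(algo_params=dict(num_epochs=100), net_size=256)
log_dir = setup_logger(
    exp_prefix='my-experiment',
    exp_id=0,
    seed=0,
    variant=variant,
    base_log_dir='/tmp/rlkit-logs',
    snapshot_mode='last',
    snapshot_gap=1,
)
# setup_logger creates base_log_dir/exp_prefix/<auto-generated exp_name>/ and
# registers debug.log, variant.json and progress.csv with the global logger;
# the other examples then redirect progress.csv during their pretraining phases.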
Example #17
    def pretrain_q_with_bc_data(self):
        """

        :return:
        """
        logger.remove_tabular_output('progress.csv',
                                     relative_to_snapshot_dir=True)
        logger.add_tabular_output('pretrain_q.csv',
                                  relative_to_snapshot_dir=True)

        self.update_policy = False
        # first train only the Q function
        for i in range(self.q_num_pretrain1_steps):
            self.eval_statistics = dict()

            train_data = self.replay_buffer.random_batch(self.bc_batch_size)
            train_data = np_to_pytorch_batch(train_data)
            obs = train_data['observations']
            next_obs = train_data['next_observations']
            # goals = train_data['resampled_goals']
            train_data['observations'] = obs  # torch.cat((obs, goals), dim=1)
            train_data['next_observations'] = next_obs  # torch.cat((next_obs, goals), dim=1)
            self.train_from_torch(train_data, pretrain=True)
            if i % self.pretraining_logging_period == 0:
                stats_with_prefix = add_prefix(self.eval_statistics,
                                               prefix="trainer/")
                logger.record_dict(stats_with_prefix)
                logger.dump_tabular(with_prefix=True, with_timestamp=False)

        self.update_policy = True
        # then train policy and Q function together
        prev_time = time.time()
        for i in range(self.q_num_pretrain2_steps):
            self.eval_statistics = dict()
            if i % self.pretraining_logging_period == 0:
                self._need_to_update_eval_statistics = True
            train_data = self.replay_buffer.random_batch(self.bc_batch_size)
            train_data = np_to_pytorch_batch(train_data)
            obs = train_data['observations']
            next_obs = train_data['next_observations']
            # goals = train_data['resampled_goals']
            train_data['observations'] = obs  # torch.cat((obs, goals), dim=1)
            train_data['next_observations'] = next_obs  # torch.cat((next_obs, goals), dim=1)
            self.train_from_torch(train_data, pretrain=True)

            if i % self.pretraining_logging_period == 0:
                self.eval_statistics["batch"] = i
                self.eval_statistics["epoch_time"] = time.time() - prev_time
                stats_with_prefix = add_prefix(self.eval_statistics,
                                               prefix="trainer/")
                logger.record_dict(stats_with_prefix)
                logger.dump_tabular(with_prefix=True, with_timestamp=False)
                prev_time = time.time()

        logger.remove_tabular_output(
            'pretrain_q.csv',
            relative_to_snapshot_dir=True,
        )
        logger.add_tabular_output(
            'progress.csv',
            relative_to_snapshot_dir=True,
        )

        self._need_to_update_eval_statistics = True
        self.eval_statistics = dict()

        if self.post_pretrain_hyperparams:
            self.set_algorithm_weights(**self.post_pretrain_hyperparams)
Example #18
    def pretrain_policy_with_bc(
        self,
        policy,
        train_buffer,
        test_buffer,
        steps,
        label="policy",
    ):
        """Given a policy, first get its optimizer, then run the policy on the train buffer, get the
        losses, and back propagate the loss. After training on a batch, test on the test buffer and
        get the statistics

        :param policy:
        :param train_buffer:
        :param test_buffer:
        :param steps:
        :param label:
        :return:
        """
        logger.remove_tabular_output(
            'progress.csv',
            relative_to_snapshot_dir=True,
        )
        logger.add_tabular_output(
            'pretrain_%s.csv' % label,
            relative_to_snapshot_dir=True,
        )

        optimizer = self.optimizers[policy]
        prev_time = time.time()
        for i in range(steps):
            train_policy_loss, train_logp_loss, train_mse_loss, train_stats = self.run_bc_batch(
                train_buffer, policy)
            train_policy_loss = train_policy_loss * self.bc_weight

            optimizer.zero_grad()
            train_policy_loss.backward()
            optimizer.step()

            test_policy_loss, test_logp_loss, test_mse_loss, test_stats = self.run_bc_batch(
                test_buffer, policy)
            test_policy_loss = test_policy_loss * self.bc_weight

            if i % self.pretraining_logging_period == 0:
                stats = {
                    "pretrain_bc/batch": i,
                    "pretrain_bc/Train Logprob Loss": ptu.get_numpy(train_logp_loss),
                    "pretrain_bc/Test Logprob Loss": ptu.get_numpy(test_logp_loss),
                    "pretrain_bc/Train MSE": ptu.get_numpy(train_mse_loss),
                    "pretrain_bc/Test MSE": ptu.get_numpy(test_mse_loss),
                    "pretrain_bc/train_policy_loss": ptu.get_numpy(train_policy_loss),
                    "pretrain_bc/test_policy_loss": ptu.get_numpy(test_policy_loss),
                    "pretrain_bc/epoch_time": time.time() - prev_time,
                }

                logger.record_dict(stats)
                logger.dump_tabular(with_prefix=True, with_timestamp=False)
                pickle.dump(
                    self.policy,
                    open(logger.get_snapshot_dir() + '/bc_%s.pkl' % label,
                         "wb"))
                prev_time = time.time()

        logger.remove_tabular_output(
            'pretrain_%s.csv' % label,
            relative_to_snapshot_dir=True,
        )
        logger.add_tabular_output(
            'progress.csv',
            relative_to_snapshot_dir=True,
        )

        if self.post_bc_pretrain_hyperparams:
            self.set_algorithm_weights(**self.post_bc_pretrain_hyperparams)
Example #19
def train_vae(
    vae: VAE,
    dset: NpGroundTruthData,
    num_epochs: int = 0,  # Do not pre-train by default
    save_period: int = 5,
    test_p: float = 0.1,  # data proportion to use for test
    vae_trainer_kwargs: Dict = {},
):
    logger.remove_tabular_output('progress.csv', relative_to_snapshot_dir=True)
    logger.add_tabular_output('vae_progress.csv',
                              relative_to_snapshot_dir=True)

    # Flatten images.
    n = dset.data.shape[0]
    data = dset.data.transpose(0, 3, 2, 1)
    data = data.reshape((n, -1))

    # Un-normalize images.
    if data.dtype != np.uint8:
        assert np.min(data) >= 0.0
        assert np.max(data) <= 1.0
        data = (data * 255).astype(np.uint8)

    # Make sure factors are normalized.
    assert np.min(dset.factors) >= 0.0
    assert np.max(dset.factors) <= 1.0

    # Split into train and test set.
    test_size = int(n * test_p)
    test_data = data[:test_size]
    train_data = data[test_size:]
    train_factors = dset.factors[test_size:]
    test_factors = dset.factors[:test_size]

    logger.get_snapshot_dir()

    trainer = VAETrainer(vae,
                         train_data,
                         test_data,
                         train_factors=train_factors,
                         test_factors=test_factors,
                         **vae_trainer_kwargs)

    for epoch in range(num_epochs):
        should_save_imgs = (epoch % save_period == 0)
        trainer.train_epoch(epoch)
        trainer.test_epoch(epoch, save_reconstruction=should_save_imgs)
        if should_save_imgs:
            trainer.dump_samples(epoch)
        trainer.update_train_weights()

    logger.save_extra_data(vae, 'vae.pkl', mode='pickle')
    logger.remove_tabular_output(
        'vae_progress.csv',
        relative_to_snapshot_dir=True,
    )
    logger.add_tabular_output(
        'progress.csv',
        relative_to_snapshot_dir=True,
    )

    return vae, trainer
Example #20
def smac_experiment(
    trainer_kwargs=None,
    algo_kwargs=None,
    qf_kwargs=None,
    policy_kwargs=None,
    context_encoder_kwargs=None,
    context_decoder_kwargs=None,
    env_name=None,
    env_params=None,
    path_loader_kwargs=None,
    latent_dim=None,
    policy_class="TanhGaussianPolicy",
    # video/debug
    debug=False,
    use_dummy_encoder=False,
    networks_ignore_context=False,
    use_ground_truth_context=False,
    save_video=False,
    save_video_period=False,
    # Pre-train params
    pretrain_rl=False,
    pretrain_offline_algo_kwargs=None,
    pretrain_buffer_kwargs=None,
    load_buffer_kwargs=None,
    saved_tasks_path=None,
    macaw_format_base_path=None,  # overrides saved_tasks_path and load_buffer_kwargs
    load_macaw_buffer_kwargs=None,
    train_task_idxs=None,
    eval_task_idxs=None,
    relabel_offline_dataset=False,
    skip_initial_data_collection_if_pretrained=False,
    relabel_kwargs=None,
    # PEARL
    n_train_tasks=0,
    n_eval_tasks=0,
    use_next_obs_in_context=False,
    tags=None,
    online_trainer_kwargs=None,
):
    if not skip_initial_data_collection_if_pretrained:
        raise NotImplementedError("deprecated! make sure to skip it!")
    if relabel_kwargs is None:
        relabel_kwargs = {}
    del tags
    pretrain_buffer_kwargs = pretrain_buffer_kwargs or {}
    context_decoder_kwargs = context_decoder_kwargs or {}
    pretrain_offline_algo_kwargs = pretrain_offline_algo_kwargs or {}
    online_trainer_kwargs = online_trainer_kwargs or {}
    register_pearl_envs()
    env_params = env_params or {}
    context_encoder_kwargs = context_encoder_kwargs or {}
    trainer_kwargs = trainer_kwargs or {}
    path_loader_kwargs = path_loader_kwargs or {}
    load_macaw_buffer_kwargs = load_macaw_buffer_kwargs or {}

    base_env = ENVS[env_name](**env_params)
    if saved_tasks_path:
        task_data = load_local_or_remote_file(saved_tasks_path,
                                              file_type='joblib')
        tasks = task_data['tasks']
        train_task_idxs = task_data['train_task_indices']
        eval_task_idxs = task_data['eval_task_indices']
        base_env.tasks = tasks
    elif macaw_format_base_path is not None:
        tasks = pickle.load(
            open('{}/tasks.pkl'.format(macaw_format_base_path), 'rb'))
        base_env.tasks = tasks
    else:
        tasks = base_env.tasks
        task_indices = base_env.get_all_task_idx()
        train_task_idxs = list(task_indices[:n_train_tasks])
        eval_task_idxs = list(task_indices[-n_eval_tasks:])
    if hasattr(base_env, 'task_to_vec'):
        train_tasks = [base_env.task_to_vec(tasks[i]) for i in train_task_idxs]
        eval_tasks = [base_env.task_to_vec(tasks[i]) for i in eval_task_idxs]
    else:
        train_tasks = [tasks[i] for i in train_task_idxs]
        eval_tasks = [tasks[i] for i in eval_task_idxs]
    if use_ground_truth_context:
        latent_dim = len(train_tasks[0])
    expl_env = NormalizedBoxEnv(base_env)

    reward_dim = 1

    if debug:
        algo_kwargs['max_path_length'] = 50
        algo_kwargs['batch_size'] = 5
        algo_kwargs['num_epochs'] = 5
        algo_kwargs['num_eval_steps_per_epoch'] = 100
        algo_kwargs['num_expl_steps_per_train_loop'] = 100
        algo_kwargs['num_trains_per_train_loop'] = 10
        algo_kwargs['min_num_steps_before_training'] = 100

    obs_dim = expl_env.observation_space.low.size
    action_dim = expl_env.action_space.low.size

    if use_next_obs_in_context:
        context_encoder_input_dim = 2 * obs_dim + action_dim + reward_dim
    else:
        context_encoder_input_dim = obs_dim + action_dim + reward_dim
    context_encoder_output_dim = latent_dim * 2

    def create_qf():
        return ConcatMlp(input_size=obs_dim + action_dim + latent_dim,
                         output_size=1,
                         **qf_kwargs)

    qf1 = create_qf()
    qf2 = create_qf()
    target_qf1 = create_qf()
    target_qf2 = create_qf()

    if isinstance(policy_class, str):
        policy_class = policy_class_from_str(policy_class)
    policy = policy_class(
        obs_dim=obs_dim + latent_dim,
        action_dim=action_dim,
        **policy_kwargs,
    )
    encoder_class = DummyMlpEncoder if use_dummy_encoder else MlpEncoder
    context_encoder = encoder_class(
        input_size=context_encoder_input_dim,
        output_size=context_encoder_output_dim,
        hidden_sizes=[200, 200, 200],
        use_ground_truth_context=use_ground_truth_context,
        **context_encoder_kwargs)
    context_decoder = MlpDecoder(input_size=obs_dim + action_dim + latent_dim,
                                 output_size=1,
                                 **context_decoder_kwargs)
    reward_predictor = context_decoder
    agent = SmacAgent(
        latent_dim,
        context_encoder,
        policy,
        reward_predictor,
        use_next_obs_in_context=use_next_obs_in_context,
        _debug_ignore_context=networks_ignore_context,
        _debug_use_ground_truth_context=use_ground_truth_context,
    )
    trainer = SmacTrainer(
        agent=agent,
        env=expl_env,
        latent_dim=latent_dim,
        qf1=qf1,
        qf2=qf2,
        target_qf1=target_qf1,
        target_qf2=target_qf2,
        reward_predictor=reward_predictor,
        context_encoder=context_encoder,
        context_decoder=context_decoder,
        _debug_ignore_context=networks_ignore_context,
        _debug_use_ground_truth_context=use_ground_truth_context,
        **trainer_kwargs)
    algorithm = MetaRLAlgorithm(
        agent=agent,
        env=expl_env,
        trainer=trainer,
        train_task_indices=train_task_idxs,
        eval_task_indices=eval_task_idxs,
        train_tasks=train_tasks,
        eval_tasks=eval_tasks,
        use_next_obs_in_context=use_next_obs_in_context,
        use_ground_truth_context=use_ground_truth_context,
        env_info_sizes=get_env_info_sizes(expl_env),
        **algo_kwargs)

    if macaw_format_base_path:
        load_macaw_buffer_onto_algo(algo=algorithm,
                                    base_directory=macaw_format_base_path,
                                    train_task_idxs=train_task_idxs,
                                    **load_macaw_buffer_kwargs)
    elif load_buffer_kwargs:
        load_buffer_onto_algo(algorithm, **load_buffer_kwargs)
    if relabel_offline_dataset:
        relabel_offline_data(algorithm,
                             tasks=tasks,
                             env=expl_env.wrapped_env,
                             **relabel_kwargs)
    if path_loader_kwargs:
        replay_buffer = algorithm.replay_buffer.task_buffers[0]
        enc_replay_buffer = algorithm.enc_replay_buffer.task_buffers[0]
        demo_test_buffer = EnvReplayBuffer(env=expl_env,
                                           **pretrain_buffer_kwargs)
        path_loader = MDPPathLoader(trainer,
                                    replay_buffer=replay_buffer,
                                    demo_train_buffer=enc_replay_buffer,
                                    demo_test_buffer=demo_test_buffer,
                                    **path_loader_kwargs)
        path_loader.load_demos()

    if pretrain_rl:
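        # Offline pretraining phase: its progress is logged to pretrain.csv,
        # and progress.csv is restored before the online phase below.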
        eval_pearl_fn = EvalPearl(algorithm, train_task_idxs, eval_task_idxs)
        pretrain_algo = OfflineMetaRLAlgorithm(
            meta_replay_buffer=algorithm.meta_replay_buffer,
            replay_buffer=algorithm.replay_buffer,
            task_embedding_replay_buffer=algorithm.enc_replay_buffer,
            trainer=trainer,
            train_tasks=train_task_idxs,
            extra_eval_fns=[eval_pearl_fn],
            use_meta_learning_buffer=algorithm.use_meta_learning_buffer,
            **pretrain_offline_algo_kwargs)
        pretrain_algo.to(ptu.device)
        logger.remove_tabular_output('progress.csv',
                                     relative_to_snapshot_dir=True)
        logger.add_tabular_output('pretrain.csv',
                                  relative_to_snapshot_dir=True)
        pretrain_algo.train()
        logger.remove_tabular_output('pretrain.csv',
                                     relative_to_snapshot_dir=True)
        logger.add_tabular_output(
            'progress.csv',
            relative_to_snapshot_dir=True,
        )
        if skip_initial_data_collection_if_pretrained:
            algorithm.num_initial_steps = 0

    algorithm.trainer.configure(**online_trainer_kwargs)
    algorithm.to(ptu.device)
    algorithm.train()
Example #21
    def pretrain_policy_with_bc(self):
        if self.buffer_for_bc_training == "demos":
            self.bc_training_buffer = self.demo_train_buffer
            self.bc_test_buffer = self.demo_test_buffer
        elif self.buffer_for_bc_training == "replay_buffer":
            self.bc_training_buffer = self.replay_buffer.train_replay_buffer
            self.bc_test_buffer = self.replay_buffer.validation_replay_buffer
        else:
            self.bc_training_buffer = None
            self.bc_test_buffer = None

        if self.load_policy_path:
            self.policy = load_local_or_remote_file(self.load_policy_path)
            ptu.copy_model_params_from_to(self.policy, self.target_policy)
            return

        logger.remove_tabular_output('progress.csv',
                                     relative_to_snapshot_dir=True)
        logger.add_tabular_output('pretrain_policy.csv',
                                  relative_to_snapshot_dir=True)
        if self.do_pretrain_rollouts:
            total_ret = self.do_rollouts()
            print("INITIAL RETURN", total_ret / 20)

        prev_time = time.time()
        for i in range(self.bc_num_pretrain_steps):
            train_policy_loss, train_logp_loss, train_mse_loss, train_log_std = self.run_bc_batch(
                self.demo_train_buffer, self.policy)
            train_policy_loss = train_policy_loss * self.bc_weight

            self.policy_optimizer.zero_grad()
            train_policy_loss.backward()
            self.policy_optimizer.step()

            test_policy_loss, test_logp_loss, test_mse_loss, test_log_std = self.run_bc_batch(
                self.demo_test_buffer, self.policy)
            test_policy_loss = test_policy_loss * self.bc_weight

            if self.do_pretrain_rollouts and i % self.pretraining_env_logging_period == 0:
                total_ret = self.do_rollouts()
                print("Return at step {} : {}".format(i, total_ret / 20))

            if i % self.pretraining_logging_period == 0:
                stats = {
                    "pretrain_bc/batch": i,
                    "pretrain_bc/Train Logprob Loss": ptu.get_numpy(train_logp_loss),
                    "pretrain_bc/Test Logprob Loss": ptu.get_numpy(test_logp_loss),
                    "pretrain_bc/Train MSE": ptu.get_numpy(train_mse_loss),
                    "pretrain_bc/Test MSE": ptu.get_numpy(test_mse_loss),
                    "pretrain_bc/train_policy_loss": ptu.get_numpy(train_policy_loss),
                    "pretrain_bc/test_policy_loss": ptu.get_numpy(test_policy_loss),
                    "pretrain_bc/epoch_time": time.time() - prev_time,
                }

                if self.do_pretrain_rollouts:
                    stats["pretrain_bc/avg_return"] = total_ret / 20

                logger.record_dict(stats)
                logger.dump_tabular(with_prefix=True, with_timestamp=False)
                pickle.dump(self.policy,
                            open(logger.get_snapshot_dir() + '/bc.pkl', "wb"))
                prev_time = time.time()

        logger.remove_tabular_output(
            'pretrain_policy.csv',
            relative_to_snapshot_dir=True,
        )
        logger.add_tabular_output(
            'progress.csv',
            relative_to_snapshot_dir=True,
        )

        ptu.copy_model_params_from_to(self.policy, self.target_policy)

        if self.post_bc_pretrain_hyperparams:
            self.set_algorithm_weights(**self.post_bc_pretrain_hyperparams)
Example #22
def train_vae_and_update_variant(variant):  # actually pretrain vae and ROLL.
    skewfit_variant = variant['skewfit_variant']
    train_vae_variant = variant['train_vae_variant']

    # prepare the background subtractor needed to perform segmentation
    if 'unet' in skewfit_variant['segmentation_method']:
        print("training opencv background model!")
        v = train_vae_variant['generate_lstm_dataset_kwargs']
        env_id = v.get('env_id', None)
        env_id_invis = invisiable_env_id[env_id]
        import gym
        import multiworld
        multiworld.register_all_envs()
        obj_invisible_env = gym.make(env_id_invis)
        init_camera = v.get('init_camera', None)

        presampled_goals = None
        if skewfit_variant.get("presampled_goals_path") is not None:
            presampled_goals = load_local_or_remote_file(
                skewfit_variant['presampled_goals_path']).item()
            print("presampled goal path is: ",
                  skewfit_variant['presampled_goals_path'])

        obj_invisible_env = ImageEnv(
            obj_invisible_env,
            v.get('imsize'),
            init_camera=init_camera,
            transpose=True,
            normalize=True,
            presampled_goals=presampled_goals,
        )

        train_num = 2000 if 'Push' in env_id else 4000
        train_bgsb(obj_invisible_env, train_num=train_num)

    if skewfit_variant.get('vae_path', None) is None:  # train new vaes
        logger.remove_tabular_output('progress.csv',
                                     relative_to_snapshot_dir=True)

        vaes, vae_train_datas, vae_test_datas = train_vae(
            train_vae_variant,
            skewfit_variant=skewfit_variant,
            return_data=True)  # one original vae, one segmented ROLL.
        if skewfit_variant.get('save_vae_data', False):
            skewfit_variant['vae_train_data'] = vae_train_datas
            skewfit_variant['vae_test_data'] = vae_test_datas

        logger.add_tabular_output(
            'progress.csv',
            relative_to_snapshot_dir=True,
        )
        skewfit_variant['vae_path'] = vaes  # just pass the VAE directly
    else:  # load pre-trained vaes
        print("load pretrain scene-/objce-VAE from: {}".format(
            skewfit_variant['vae_path']))
        data = torch.load(osp.join(skewfit_variant['vae_path'], 'params.pkl'))
        vae_original = data['vae_original']
        vae_segmented = data['lstm_segmented']
        skewfit_variant['vae_path'] = [vae_segmented, vae_original]

        generate_vae_dataset_fctn = train_vae_variant.get(
            'generate_vae_data_fctn', generate_vae_dataset)
        generate_lstm_dataset_fctn = train_vae_variant.get(
            'generate_lstm_data_fctn')
        assert generate_lstm_dataset_fctn is not None, "Must provide a custom generate lstm pretraining dataset function!"

        train_data_lstm, test_data_lstm, info_lstm = generate_lstm_dataset_fctn(
            train_vae_variant['generate_lstm_dataset_kwargs'],
            segmented=True,
            segmentation_method=skewfit_variant['segmentation_method'])

        train_data_ori, test_data_ori, info_ori = generate_vae_dataset_fctn(
            train_vae_variant['generate_vae_dataset_kwargs'])

        train_datas = [train_data_lstm, train_data_ori]
        test_datas = [test_data_lstm, test_data_ori]

        if skewfit_variant.get('save_vae_data', False):
            skewfit_variant['vae_train_data'] = train_datas
            skewfit_variant['vae_test_data'] = test_datas
Example #23
def train_vae(
        variant,
        return_data=False,
        skewfit_variant=None):  # actually train both the vae and the lstm
    from rlkit.util.ml_util import PiecewiseLinearSchedule
    from rlkit.torch.vae.conv_vae import (
        ConvVAE, )
    import rlkit.torch.vae.conv_vae as conv_vae
    import ROLL.LSTM_model as LSTM_model
    from ROLL.LSTM_model import ConvLSTM2
    from ROLL.LSTM_trainer import ConvLSTMTrainer
    from rlkit.torch.vae.vae_trainer import ConvVAETrainer
    import rlkit.torch.pytorch_util as ptu
    from rlkit.pythonplusplus import identity
    import torch
    seg_pretrain = variant['seg_pretrain']
    ori_pretrain = variant['ori_pretrain']
    generate_vae_dataset_fctn = variant.get('generate_vae_data_fctn',
                                            generate_vae_dataset)
    generate_lstm_dataset_fctn = variant.get('generate_lstm_data_fctn')
    assert generate_lstm_dataset_fctn is not None, "Must provide a custom generate lstm pretraining dataset function!"

    train_data_lstm, test_data_lstm, info_lstm = generate_lstm_dataset_fctn(
        variant['generate_lstm_dataset_kwargs'],
        segmented=True,
        segmentation_method=skewfit_variant['segmentation_method'])

    train_data_ori, test_data_ori, info_ori = generate_vae_dataset_fctn(
        variant['generate_vae_dataset_kwargs'])

    if 'beta_schedule_kwargs' in variant:
        beta_schedule = PiecewiseLinearSchedule(
            **variant['beta_schedule_kwargs'])
    else:
        beta_schedule = None
    if variant.get('decoder_activation', None) == 'sigmoid':
        decoder_activation = torch.nn.Sigmoid()
    else:
        decoder_activation = identity

    architecture = variant['vae_kwargs'].get('architecture', None)
    if not architecture and variant.get('imsize') == 84:
        architecture = conv_vae.imsize84_default_architecture
    elif not architecture and variant.get('imsize') == 48:
        architecture = conv_vae.imsize48_default_architecture
    variant['vae_kwargs']['architecture'] = architecture
    variant['vae_kwargs']['imsize'] = variant.get('imsize')

    architecture = variant['lstm_kwargs'].get('architecture', None)
    if not architecture and variant.get('imsize') == 84:
        architecture = None  # TODO LSTM: wrap an 84 lstm architecture
    elif not architecture and variant.get('imsize') == 48:
        architecture = LSTM_model.imsize48_default_architecture
    variant['lstm_kwargs']['architecture'] = architecture
    variant['lstm_kwargs']['imsize'] = variant.get('imsize')

    train_datas = [
        train_data_lstm,
        train_data_ori,
    ]
    test_datas = [
        test_data_lstm,
        test_data_ori,
    ]
    names = [
        'lstm_seg_pretrain',
        'vae_ori_pretrain',
    ]
    vaes = []
    env_id = variant['generate_lstm_dataset_kwargs'].get('env_id')
    assert env_id is not None
    lstm_pretrain_vae_only = variant.get('lstm_pretrain_vae_only', False)

    for idx in range(2):
        train_data, test_data, name = train_datas[idx], test_datas[idx], names[idx]

        logger.add_tabular_output('{}_progress.csv'.format(name),
                                  relative_to_snapshot_dir=True)

        if idx == 1:  # train the original vae
            representation_size = variant.get(
                "vae_representation_size", variant.get('representation_size'))
            beta = variant.get('vae_beta', variant.get('beta'))
            m = ConvVAE(representation_size,
                        decoder_output_activation=decoder_activation,
                        **variant['vae_kwargs'])
            t = ConvVAETrainer(train_data,
                               test_data,
                               m,
                               beta=beta,
                               beta_schedule=beta_schedule,
                               **variant['algo_kwargs'])
        else:  # train the segmentation lstm
            lstm_version = variant.get('lstm_version', 2)
            if lstm_version == 2:
                lstm_class = ConvLSTM2
            else:
                raise NotImplementedError(
                    'Unsupported lstm_version: {}'.format(lstm_version))

            representation_size = variant.get(
                "lstm_representation_size", variant.get('representation_size'))
            beta = variant.get('lstm_beta', variant.get('beta'))
            m = lstm_class(representation_size,
                           decoder_output_activation=decoder_activation,
                           **variant['lstm_kwargs'])
            t = ConvLSTMTrainer(train_data,
                                test_data,
                                m,
                                beta=beta,
                                beta_schedule=beta_schedule,
                                **variant['algo_kwargs'])

        m.to(ptu.device)

        vaes.append(m)

        print("test data len: ", len(test_data))
        print("train data len: ", len(train_data))

        save_period = variant['save_period']

        pjhome = os.environ['PJHOME']
        data_file_path = osp.join(
            pjhome, 'data/local/pre-train-lstm',
            '{}-{}-{}-0.3-0.5.npy'.format(env_id, 'seg-color', 500))
        puck_pos_path = osp.join(
            pjhome, 'data/local/pre-train-lstm',
            '{}-{}-{}-0.3-0.5-puck-pos.npy'.format(env_id, 'seg-color', 500))
        if env_id == 'SawyerPushHurdle-v0' and osp.exists(data_file_path):
            all_data = np.load(data_file_path)
            puck_pos = np.load(puck_pos_path)
            all_data = normalize_image(all_data.copy())
            obj_states = puck_pos
        else:
            all_data = np.concatenate([train_data_lstm, test_data_lstm],
                                      axis=0)
            all_data = normalize_image(all_data.copy())
            obj_states = info_lstm['obj_state']

        obj = 'door' if 'Door' in env_id else 'puck'

        num_epochs = (variant['num_lstm_epochs']
                      if idx == 0 else variant['num_vae_epochs'])

        if (idx == 0 and seg_pretrain) or (idx == 1 and ori_pretrain):
            for epoch in range(num_epochs):
                should_save_imgs = (epoch % save_period == 0)
                if idx == 0:  # only LSTM trainer has 'only_train_vae' argument
                    t.train_epoch(epoch, only_train_vae=lstm_pretrain_vae_only)
                    t.test_epoch(epoch,
                                 save_reconstruction=should_save_imgs,
                                 save_prefix='r_' + name,
                                 only_train_vae=lstm_pretrain_vae_only)
                else:
                    t.train_epoch(epoch)
                    t.test_epoch(
                        epoch,
                        save_reconstruction=should_save_imgs,
                        save_prefix='r_' + name,
                    )

                if should_save_imgs:
                    t.dump_samples(epoch, save_prefix='s_' + name)

                    if idx == 0:
                        compare_latent_distance(
                            m,
                            all_data,
                            obj_states,
                            obj_name=obj,
                            save_dir=logger.get_snapshot_dir(),
                            save_name='lstm_latent_distance_{}.png'.format(
                                epoch))
                        test_lstm_traj(
                            env_id,
                            m,
                            save_path=logger.get_snapshot_dir(),
                            save_name='lstm_test_traj_{}.png'.format(epoch))
                        test_masked_traj_lstm(
                            env_id,
                            m,
                            save_dir=logger.get_snapshot_dir(),
                            save_name='masked_test_{}.png'.format(epoch))

                t.update_train_weights()

            logger.save_extra_data(m, '{}.pkl'.format(name), mode='pickle')

        logger.remove_tabular_output('{}_progress.csv'.format(name),
                                     relative_to_snapshot_dir=True)

        if idx == 0 and variant.get("only_train_lstm", False):
            exit()

    if return_data:
        return vaes, train_datas, test_datas
    return m
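For orientation, here is a minimal sketch of the variant dictionary this pretraining routine reads, inferred only from the key accesses above; every value is an illustrative placeholder and `my_generate_lstm_dataset` is a hypothetical stand-in for the required custom dataset function:

variant = dict(
    seg_pretrain=True,
    ori_pretrain=True,
    generate_lstm_data_fctn=my_generate_lstm_dataset,  # hypothetical helper, required
    generate_lstm_dataset_kwargs=dict(env_id='SawyerPushHurdle-v0'),
    generate_vae_dataset_kwargs=dict(),
    imsize=48,
    representation_size=6,   # or vae_representation_size / lstm_representation_size
    beta=20,                 # or vae_beta / lstm_beta
    vae_kwargs=dict(),       # architecture is filled in above when omitted
    lstm_kwargs=dict(),
    algo_kwargs=dict(),
    save_period=25,
    num_lstm_epochs=1000,
    num_vae_epochs=1000,
)
# the routine also expects the PJHOME environment variable to be set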
Exemple #24
0
def get_n_train_vae(latent_dim,
                    env,
                    vae_train_epochs,
                    num_image_examples,
                    vae_kwargs,
                    vae_trainer_kwargs,
                    vae_architecture,
                    vae_save_period=10,
                    vae_test_p=.9,
                    decoder_activation='sigmoid',
                    vae_class='VAE',
                    **kwargs):
    env.goal_sampling_mode = 'test'
    image_examples = unnormalize_image(
        env.sample_goals(num_image_examples)['desired_goal'])
    n = int(num_image_examples * vae_test_p)  # despite the name, this fraction forms the training split
    train_dataset = ImageObservationDataset(image_examples[:n, :])
    test_dataset = ImageObservationDataset(image_examples[n:, :])

    if decoder_activation == 'sigmoid':
        decoder_activation = torch.nn.Sigmoid()

    vae_class = vae_class.lower()
    if vae_class == 'vae':
        vae_class = ConvVAE
    elif vae_class == 'spatialvae':
        vae_class = SpatialAutoEncoder
    else:
        raise RuntimeError("Invalid VAE Class: {}".format(vae_class))

    vae = vae_class(latent_dim,
                    architecture=vae_architecture,
                    decoder_output_activation=decoder_activation,
                    **vae_kwargs)

    trainer = ConvVAETrainer(vae, **vae_trainer_kwargs)

    logger.remove_tabular_output('progress.csv', relative_to_snapshot_dir=True)
    logger.add_tabular_output('vae_progress.csv',
                              relative_to_snapshot_dir=True)
    for epoch in range(vae_train_epochs):
        should_save_imgs = (epoch % vae_save_period == 0)
        trainer.train_epoch(epoch, train_dataset)
        trainer.test_epoch(epoch, test_dataset)

        if should_save_imgs:
            trainer.dump_reconstructions(epoch)
            trainer.dump_samples(epoch)
        stats = trainer.get_diagnostics()
        for k, v in stats.items():
            logger.record_tabular(k, v)

        logger.dump_tabular()
        trainer.end_epoch(epoch)

        if epoch % 50 == 0:
            logger.save_itr_params(epoch, vae)
    logger.save_extra_data(vae, 'vae.pkl', mode='pickle')
    logger.remove_tabular_output('vae_progress.csv',
                                 relative_to_snapshot_dir=True)
    logger.add_tabular_output('progress.csv', relative_to_snapshot_dir=True)
    return vae
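A rough call sketch, assuming an rlkit-style image goal environment (`my_env`) that implements `sample_goals`, and reusing the `imsize48_default_architecture` referenced in the previous example; the argument values, including the trainer kwargs, are illustrative assumptions rather than recommended settings:

vae = get_n_train_vae(
    latent_dim=16,
    env=my_env,                        # hypothetical environment instance
    vae_train_epochs=200,
    num_image_examples=10000,
    vae_kwargs=dict(input_channels=3),
    vae_trainer_kwargs=dict(lr=1e-3),  # assumed ConvVAETrainer keyword
    vae_architecture=conv_vae.imsize48_default_architecture,
    vae_save_period=10,
    vae_class='VAE',
)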
Exemple #25
0
def run_experiment(argv):
    default_log_dir = config.LOCAL_LOG_DIR
    now = datetime.datetime.now(dateutil.tz.tzlocal())

    # avoid name clashes when running distributed jobs
    rand_id = str(uuid.uuid4())[:5]
    timestamp = now.strftime('%Y_%m_%d_%H_%M_%S_%f_%Z')

    default_exp_name = 'experiment_%s_%s' % (timestamp, rand_id)
    parser = argparse.ArgumentParser()
    parser.add_argument(
        '--n_parallel',
        type=int,
        default=1,
        help=
        'Number of parallel workers to perform rollouts. 0 => don\'t start any workers'
    )
    parser.add_argument('--exp_name',
                        type=str,
                        default=default_exp_name,
                        help='Name of the experiment.')
    parser.add_argument('--log_dir',
                        type=str,
                        default=None,
                        help='Path to save the log and iteration snapshot.')
    parser.add_argument('--snapshot_mode',
                        type=str,
                        default='all',
                        help='Mode to save the snapshot. Can be either "all" '
                        '(all iterations will be saved), "last" (only '
                        'the last iteration will be saved), "gap" (every '
                        '`snapshot_gap` iterations are saved), or "none" '
                        '(do not save snapshots)')
    parser.add_argument('--snapshot_gap',
                        type=int,
                        default=1,
                        help='Gap between snapshot iterations.')
    parser.add_argument('--tabular_log_file',
                        type=str,
                        default='progress.csv',
                        help='Name of the tabular log file (in csv).')
    parser.add_argument('--text_log_file',
                        type=str,
                        default='debug.log',
                        help='Name of the text log file (in pure text).')
    parser.add_argument('--params_log_file',
                        type=str,
                        default='params.json',
                        help='Name of the parameter log file (in json).')
    parser.add_argument('--variant_log_file',
                        type=str,
                        default='variant.json',
                        help='Name of the variant log file (in json).')
    parser.add_argument(
        '--resume_from',
        type=str,
        default=None,
        help='Name of the pickle file to resume experiment from.')
    parser.add_argument('--plot',
                        type=ast.literal_eval,
                        default=False,
                        help='Whether to plot the iteration results')
    parser.add_argument(
        '--log_tabular_only',
        type=ast.literal_eval,
        default=False,
        help=
        'Whether to only print the tabular log information (in a horizontal format)'
    )
    parser.add_argument('--seed', type=int, help='Random seed for numpy')
    parser.add_argument('--args_data',
                        type=str,
                        help='Pickled data for stub objects')
    parser.add_argument('--variant_data',
                        type=str,
                        help='Pickled data for variant configuration')
    parser.add_argument('--use_cloudpickle',
                        type=ast.literal_eval,
                        default=False)
    parser.add_argument('--code_diff',
                        type=str,
                        help='A string of the code diff to save.')
    parser.add_argument('--commit_hash',
                        type=str,
                        help='A string of the commit hash')
    parser.add_argument('--script_name',
                        type=str,
                        help='Name of the launched script')

    args = parser.parse_args(argv[1:])

    if args.seed is not None:
        set_seed(args.seed)

    if args.n_parallel > 0:
        from rllab.sampler import parallel_sampler
        parallel_sampler.initialize(n_parallel=args.n_parallel)
        if args.seed is not None:
            parallel_sampler.set_seed(args.seed)

    if args.plot:
        from rllab.plotter import plotter
        plotter.init_worker()

    if args.log_dir is None:
        log_dir = osp.join(default_log_dir, args.exp_name)
    else:
        log_dir = args.log_dir
    tabular_log_file = osp.join(log_dir, args.tabular_log_file)
    text_log_file = osp.join(log_dir, args.text_log_file)
    params_log_file = osp.join(log_dir, args.params_log_file)

    if args.variant_data is not None:
        variant_data = pickle.loads(base64.b64decode(args.variant_data))
        variant_log_file = osp.join(log_dir, args.variant_log_file)
        logger.log_variant(variant_log_file, variant_data)
    else:
        variant_data = None

    if not args.use_cloudpickle:
        raise NotImplementedError("Not supporting non-cloud-pickle")

    logger.add_text_output(text_log_file)
    logger.add_tabular_output(tabular_log_file)
    prev_snapshot_dir = logger.get_snapshot_dir()
    prev_mode = logger.get_snapshot_mode()
    logger.set_snapshot_dir(log_dir)
    logger.set_snapshot_mode(args.snapshot_mode)
    logger.set_snapshot_gap(args.snapshot_gap)
    logger.set_log_tabular_only(args.log_tabular_only)
    logger.push_prefix("[%s] " % args.exp_name)
    """
    Save information for code reproducibility.
    """
    if args.code_diff is not None:
        code_diff_str = cloudpickle.loads(base64.b64decode(args.code_diff))
        with open(osp.join(log_dir, "code.diff"), "w") as f:
            f.write(code_diff_str)
    if args.commit_hash is not None:
        with open(osp.join(log_dir, "commit_hash.txt"), "w") as f:
            f.write(args.commit_hash)
    if args.script_name is not None:
        with open(osp.join(log_dir, "script_name.txt"), "w") as f:
            f.write(args.script_name)

    if args.resume_from is not None:
        data = joblib.load(args.resume_from)
        assert 'algo' in data
        algo = data['algo']
        algo.train()
    else:
        # unpickle the serialized experiment call passed via --args_data and run it
        if args.use_cloudpickle:
            method_call = cloudpickle.loads(base64.b64decode(args.args_data))
            method_call(variant_data)
        else:
            data = pickle.loads(base64.b64decode(args.args_data))
            maybe_iter = concretize(data)
            if is_iterable(maybe_iter):
                for _ in maybe_iter:
                    pass

    logger.set_snapshot_mode(prev_mode)
    logger.set_snapshot_dir(prev_snapshot_dir)
    logger.remove_tabular_output(tabular_log_file)
    logger.remove_text_output(text_log_file)
    logger.pop_prefix()
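Finally, a hedged sketch of how a launcher might drive this entry point in-process; `my_method_call` is a hypothetical experiment body, the payload encoding simply mirrors the `--args_data`/`--use_cloudpickle` handling above, and the rllab logger and config this script imports must be available:

import base64
import cloudpickle

def my_method_call(variant_data):  # hypothetical experiment body
    print('running with variant:', variant_data)

argv = [
    'run_experiment',         # argv[0] is skipped by parse_args(argv[1:])
    '--n_parallel', '0',      # don't start rollout workers
    '--exp_name', 'demo_run',
    '--log_dir', '/tmp/demo_run',
    '--snapshot_mode', 'last',
    '--seed', '0',
    '--use_cloudpickle', 'True',
    '--args_data',
    base64.b64encode(cloudpickle.dumps(my_method_call)).decode('ascii'),
]
run_experiment(argv)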