Example #1
    def test_epoch(
        self,
        epoch,
    ):
        self.model.eval()
        val_losses = []
        per_dim_losses = np.zeros((self.num_batches, self.y_train.shape[1]))
        for batch in range(self.num_batches):
            inputs_np, labels_np = self.random_batch(
                self.X_test, self.y_test, batch_size=self.batch_size)
            inputs = ptu.Variable(ptu.from_numpy(inputs_np))
            labels = ptu.Variable(ptu.from_numpy(labels_np))
            outputs = self.model(inputs)
            loss = self.criterion(outputs, labels)
            val_losses.append(loss.data[0])
            per_dim_loss = np.mean(
                np.power(ptu.get_numpy(outputs - labels), 2), axis=0)
            per_dim_losses[batch] = per_dim_loss

        logger.record_tabular("test/epoch", epoch)
        logger.record_tabular("test/loss", np.mean(np.array(val_losses)))
        for i in range(self.y_train.shape[1]):
            logger.record_tabular("test/dim " + str(i) + " loss",
                                  np.mean(per_dim_losses[:, i]))
        logger.dump_tabular()
Example #2
def simulate_policy(args):
    data = joblib.load(args.file)
    model = data['model']
    env = data['env']
    orig_policy = data['mpc_controller']
    print("Policy loaded")
    if args.pause:
        import ipdb
        ipdb.set_trace()
    policy = GradientBasedMPCController(
        env,
        model,
        mpc_horizon=1,
        num_grad_steps=10,
        learning_rate=1e-1,
        warm_start=False,
    )
    while True:
        path = rollout(
            env,
            policy,
            orig_policy,
            max_path_length=args.H,
            animated=True,
        )
        if hasattr(env, "log_diagnostics"):
            env.log_diagnostics([path])
        logger.dump_tabular()
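This simulate_policy script (like the similar ones later in this list) is meant to be invoked from the command line; the exact flags vary per script, but a minimal entry point consistent with the args.file, args.H, and args.pause attributes used above (and with the full argparse setup shown in Example #29) could look like the sketch below. The defaults here are placeholders.

import argparse

if __name__ == "__main__":
    # Minimal CLI wiring assumed by the simulate_policy snippets in this list;
    # flag names are inferred from the args.* attributes they access.
    parser = argparse.ArgumentParser()
    parser.add_argument('file', type=str, help='path to the snapshot file')
    parser.add_argument('--H', type=int, default=300,
                        help='max length of rollout')
    parser.add_argument('--pause', action='store_true')
    args = parser.parse_args()
    simulate_policy(args)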
Example #3
def experiment(variant):
    num_rollouts = variant['num_rollouts']
    H = variant['H']
    render = variant['render']
    data = joblib.load(variant['qf_path'])
    qf = data['qf']
    env = data['env']
    qf_policy = data['policy']
    if ptu.gpu_enabled():
        qf.to(ptu.device)
        qf_policy.to(ptu.device)
    policy_class = variant['policy_class']
    if policy_class == StateOnlySdqBasedSqpOcPolicy:
        policy = policy_class(qf, env, qf_policy, **variant['policy_params'])
    else:
        policy = policy_class(qf, env, **variant['policy_params'])
    paths = []
    for _ in range(num_rollouts):
        goal = env.sample_goal_for_rollout()
        path = multitask_rollout(
            env,
            policy,
            goal,
            discount=variant['discount'],
            max_path_length=H,
            animated=render,
        )
        paths.append(path)
    env.log_diagnostics(paths)
    logger.dump_tabular(with_timestamp=False)
Example #4
    def test_epoch(self, epoch, save_network=True, batches=100):
        self.model.eval()
        mses = []
        losses = []
        for batch_idx in range(batches):
            data = self.get_batch(train=False)
            z = data["z"]
            z_proj = data['z_proj']
            z_proj_hat = self.model(z)
            mse = self.mse_loss(z_proj_hat, z_proj)
            loss = mse

            mses.append(mse.data[0])
            losses.append(loss.data[0])

        logger.record_tabular("test/epoch", epoch)
        logger.record_tabular("test/MSE", np.mean(mses))
        logger.record_tabular("test/loss", np.mean(losses))

        logger.dump_tabular()
        if save_network:
            logger.save_itr_params(epoch,
                                   self.model,
                                   prefix='reproj',
                                   save_anyway=True)
Example #5
def experiment(variant):
    num_rollouts = variant['num_rollouts']
    H = variant['H']
    render = variant['render']
    data = joblib.load(variant['qf_path'])
    policy_params = variant['policy_params']
    if 'model' in data:
        model = data['model']
    else:
        qf = data['qf']
        model = ModelExtractor(qf)
        policy_params['model_learns_deltas'] = False
    env = data['env']
    if ptu.gpu_enabled():
        model.to(ptu.device)
    policy = variant['policy_class'](
        model,
        env,
        **policy_params
    )
    paths = []
    for _ in range(num_rollouts):
        goal = env.sample_goal_for_rollout()
        path = multitask_rollout(
            env,
            policy,
            goal,
            discount=0,
            max_path_length=H,
            animated=render,
        )
        paths.append(path)
    env.log_diagnostics(paths)
    logger.dump_tabular(with_timestamp=False)
Example #6
    def evaluate(self, epoch):
        """
        Perform evaluation for this algorithm.

        :param epoch: The epoch number.
        """
        statistics = OrderedDict()

        train_batch = self.get_batch()
        statistics.update(self._statistics_from_batch(train_batch, "Train"))

        logger.log("Collecting samples for evaluation")
        test_paths = self._sample_eval_paths()
        statistics.update(
            get_generic_path_information(
                test_paths,
                stat_prefix="Test",
            ))
        statistics.update(self._statistics_from_paths(test_paths, "Test"))
        average_returns = get_average_returns(test_paths)
        statistics['AverageReturn'] = average_returns

        statistics['Epoch'] = epoch

        for key, value in statistics.items():
            logger.record_tabular(key, value)

        self.env.log_diagnostics(test_paths)
        logger.dump_tabular(with_prefix=False, with_timestamp=False)
Example #7
def train_vae(variant, return_data=False):
    from railrl.misc.ml_util import PiecewiseLinearSchedule
    from railrl.torch.vae.vae_trainer import ConvVAETrainer
    from railrl.core import logger
    beta = variant["beta"]
    use_linear_dynamics = variant.get('use_linear_dynamics', False)
    generate_vae_dataset_fctn = variant.get('generate_vae_data_fctn',
                                            generate_vae_dataset)
    variant['generate_vae_dataset_kwargs'][
        'use_linear_dynamics'] = use_linear_dynamics
    train_dataset, test_dataset, info = generate_vae_dataset_fctn(
        variant['generate_vae_dataset_kwargs'])
    if use_linear_dynamics:
        action_dim = train_dataset.data['actions'].shape[2]
    else:
        action_dim = 0
    model = get_vae(variant, action_dim)

    logger.save_extra_data(info)
    logger.get_snapshot_dir()
    if 'beta_schedule_kwargs' in variant:
        beta_schedule = PiecewiseLinearSchedule(
            **variant['beta_schedule_kwargs'])
    else:
        beta_schedule = None

    vae_trainer_class = variant.get('vae_trainer_class', ConvVAETrainer)
    trainer = vae_trainer_class(model,
                                beta=beta,
                                beta_schedule=beta_schedule,
                                **variant['algo_kwargs'])
    save_period = variant['save_period']

    dump_skew_debug_plots = variant.get('dump_skew_debug_plots', False)
    for epoch in range(variant['num_epochs']):
        should_save_imgs = (epoch % save_period == 0)
        trainer.train_epoch(epoch, train_dataset)
        trainer.test_epoch(epoch, test_dataset)

        if should_save_imgs:
            trainer.dump_reconstructions(epoch)
            trainer.dump_samples(epoch)
            if dump_skew_debug_plots:
                trainer.dump_best_reconstruction(epoch)
                trainer.dump_worst_reconstruction(epoch)
                trainer.dump_sampling_histogram(epoch)

        stats = trainer.get_diagnostics()
        for k, v in stats.items():
            logger.record_tabular(k, v)
        logger.dump_tabular()
        trainer.end_epoch(epoch)

        if epoch % 50 == 0:
            logger.save_itr_params(epoch, model)
    logger.save_extra_data(model, 'vae.pkl', mode='pickle')
    if return_data:
        return model, train_dataset, test_dataset
    return model
Example #8
    def test_epoch(self, epoch, save_vae=True, **kwargs):
        self.model.eval()
        losses = []
        kles = []
        zs = []

        recon_logging_dict = {
            'MSE': [],
            'WSE': [],
        }
        for k in self.extra_recon_logging:
            recon_logging_dict[k] = []

        beta = self.beta_schedule.get_value(epoch)
        for batch_idx in range(100):
            data = self.get_batch(train=False)
            obs = data['obs']
            next_obs = data['next_obs']
            actions = data['actions']
            recon_batch, mu, logvar = self.model(next_obs)
            mse = self.logprob(recon_batch, next_obs)
            wse = self.logprob(recon_batch, next_obs, unorm_weights=self.recon_weights)
            for k, idx in self.extra_recon_logging.items():
                recon_loss = self.logprob(recon_batch, next_obs, idx=idx)
                recon_logging_dict[k].append(recon_loss.data[0])
            kle = self.kl_divergence(mu, logvar)
            if self.recon_loss_type == 'mse':
                loss = mse + beta * kle
            elif self.recon_loss_type == 'wse':
                loss = wse + beta * kle
            z_data = ptu.get_numpy(mu.cpu())
            for i in range(len(z_data)):
                zs.append(z_data[i, :])
            losses.append(loss.data[0])
            recon_logging_dict['WSE'].append(wse.data[0])
            recon_logging_dict['MSE'].append(mse.data[0])
            kles.append(kle.data[0])
        zs = np.array(zs)
        self.model.dist_mu = zs.mean(axis=0)
        self.model.dist_std = zs.std(axis=0)

        for k in recon_logging_dict:
            logger.record_tabular("/".join(["test", k]), np.mean(recon_logging_dict[k]))
        logger.record_tabular("test/KL", np.mean(kles))
        logger.record_tabular("test/loss", np.mean(losses))
        logger.record_tabular("beta", beta)

        process = psutil.Process(os.getpid())
        logger.record_tabular("RAM Usage (Mb)", int(process.memory_info().rss / 1000000))

        num_active_dims = 0
        for std in self.model.dist_std:
            if std > 0.15:
                num_active_dims += 1
        logger.record_tabular("num_active_dims", num_active_dims)

        logger.dump_tabular()
        if save_vae:
            logger.save_itr_params(epoch, self.model, prefix='vae', save_anyway=True)  # slow...
Example #9
def run_task(variant):
    from railrl.core import logger
    print(variant)
    logger.log("Hello from script")
    logger.log("variant: " + str(variant))
    logger.record_tabular("value", 1)
    logger.dump_tabular()
    logger.log("snapshot_dir:", logger.get_snapshot_dir())
Example #10
    def pretrain_policy_with_bc(self):
        logger.remove_tabular_output(
            'progress.csv', relative_to_snapshot_dir=True
        )
        logger.add_tabular_output(
            'pretrain_policy.csv', relative_to_snapshot_dir=True
        )
        if self.do_pretrain_rollouts:
            total_ret = self.do_rollouts()
            print("INITIAL RETURN", total_ret/20)

        prev_time = time.time()
        for i in range(self.bc_num_pretrain_steps):
            train_policy_loss, train_logp_loss, train_mse_loss, train_log_std = self.run_bc_batch(self.demo_train_buffer, self.policy)
            train_policy_loss = train_policy_loss * self.bc_weight

            self.policy_optimizer.zero_grad()
            train_policy_loss.backward()
            self.policy_optimizer.step()

            test_policy_loss, test_logp_loss, test_mse_loss, test_log_std = self.run_bc_batch(self.demo_test_buffer, self.policy)
            test_policy_loss = test_policy_loss * self.bc_weight

            if self.do_pretrain_rollouts and i % self.pretraining_env_logging_period == 0:
                total_ret = self.do_rollouts()
                print("Return at step {} : {}".format(i, total_ret/20))

            if i % self.pretraining_logging_period == 0:
                stats = {
                    "pretrain_bc/batch": i,
                    "pretrain_bc/Train Logprob Loss": ptu.get_numpy(train_logp_loss),
                    "pretrain_bc/Test Logprob Loss": ptu.get_numpy(test_logp_loss),
                    "pretrain_bc/Train MSE": ptu.get_numpy(train_mse_loss),
                    "pretrain_bc/Test MSE": ptu.get_numpy(test_mse_loss),
                    "pretrain_bc/train_policy_loss": ptu.get_numpy(train_policy_loss),
                    "pretrain_bc/test_policy_loss": ptu.get_numpy(test_policy_loss),
                    "pretrain_bc/epoch_time": time.time() - prev_time,
                }

                if self.do_pretrain_rollouts:
                    stats["pretrain_bc/avg_return"] = total_ret / 20

                logger.record_dict(stats)
                logger.dump_tabular(with_prefix=True, with_timestamp=False)
                pickle.dump(self.policy, open(logger.get_snapshot_dir() + '/bc.pkl', "wb"))
                prev_time = time.time()

        logger.remove_tabular_output(
            'pretrain_policy.csv',
            relative_to_snapshot_dir=True,
        )
        logger.add_tabular_output(
            'progress.csv',
            relative_to_snapshot_dir=True,
        )

        if self.post_bc_pretrain_hyperparams:
            self.set_algorithm_weights(**self.post_bc_pretrain_hyperparams)
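The remove_tabular_output/add_tabular_output calls above temporarily redirect tabular rows to a separate CSV during pretraining and then restore progress.csv; Example #14 and get_n_train_vae later in this list use the same idiom. A stripped-down sketch of the pattern (the pretrain.csv filename and the compute_stats callback are placeholders):

from railrl.core import logger

def log_pretraining_phase(num_steps, compute_stats):
    # Redirect tabular rows to a dedicated CSV for this phase...
    logger.remove_tabular_output('progress.csv', relative_to_snapshot_dir=True)
    logger.add_tabular_output('pretrain.csv', relative_to_snapshot_dir=True)
    for i in range(num_steps):
        logger.record_dict(compute_stats(i))
        logger.dump_tabular(with_prefix=True, with_timestamp=False)
    # ...then restore the main progress.csv output.
    logger.remove_tabular_output('pretrain.csv', relative_to_snapshot_dir=True)
    logger.add_tabular_output('progress.csv', relative_to_snapshot_dir=True)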
Example #11
def simulate_policy(args):
    data = pickle.load(open(args.file, "rb"))
    policy_key = args.policy_type + '/policy'
    if policy_key in data:
        policy = data[policy_key]
    else:
        raise Exception("No policy found in loaded dict. Keys: {}".format(
            data.keys()))

    env_key = args.env_type + '/env'
    if env_key in data:
        env = data[env_key]
    else:
        raise Exception("No environment found in loaded dict. Keys: {}".format(
            data.keys()))

    # robosuite env specific things
    env._wrapped_env.has_renderer = True
    env.reset()
    env.viewer.set_camera(camera_id=0)

    if isinstance(env, RemoteRolloutEnv):
        env = env._wrapped_env
    print("Policy loaded")

    if args.enable_render:
        # some environments need to be reconfigured for visualization
        env.enable_render()
    if args.gpu:
        ptu.set_gpu_mode(True)
    if hasattr(policy, "to"):
        policy.to(ptu.device)
    if hasattr(env, "vae"):
        env.vae.to(ptu.device)

    if args.pause:
        import ipdb
        ipdb.set_trace()
    if isinstance(policy, PyTorchModule):
        policy.train(False)
    paths = []
    while True:
        paths.append(
            deprecated_rollout(
                env,
                policy,
                max_path_length=args.H,
                render=not args.hide,
            ))
        if args.log_diagnostics:
            if hasattr(env, "log_diagnostics"):
                env.log_diagnostics(paths, logger)
            for k, v in eval_util.get_generic_path_information(paths).items():
                logger.record_tabular(k, v)
            logger.dump_tabular()
Example #12
    def train(self):
        timer.return_global_times = True
        for _ in range(self.num_epochs):
            self._begin_epoch()
            # logger.save_itr_params(self.epoch, self._get_snapshot())
            # timer.stamp('saving')
            log_dict, _ = self._train()
            logger.record_dict(log_dict)
            logger.dump_tabular(with_prefix=True, with_timestamp=False)
            logger.save_itr_params(self.epoch, self._get_snapshot())
            self._end_epoch()
Example #13
    def _try_to_eval(self, epoch, eval_paths=None):
        logger.save_extra_data(self.get_extra_data_to_save(epoch))

        params = self.get_epoch_snapshot(epoch)
        logger.save_itr_params(epoch, params)

        if self._can_evaluate():
            self.evaluate(epoch, eval_paths=eval_paths)

            # params = self.get_epoch_snapshot(epoch)
            # logger.save_itr_params(epoch, params)
            table_keys = logger.get_table_key_set()
            if self._old_table_keys is not None:
                assert table_keys == self._old_table_keys, (
                    "Table keys cannot change from iteration to iteration."
                )
            self._old_table_keys = table_keys

            logger.record_tabular(
                "Number of train steps total",
                self._n_train_steps_total,
            )
            logger.record_tabular(
                "Number of env steps total",
                self._n_env_steps_total,
            )
            logger.record_tabular(
                "Number of rollouts total",
                self._n_rollouts_total,
            )

            if self.collection_mode != 'online-parallel':
                times_itrs = gt.get_times().stamps.itrs
                train_time = times_itrs['train'][-1]
                sample_time = times_itrs['sample'][-1]
                if 'eval' in times_itrs:
                    eval_time = times_itrs['eval'][-1] if epoch > 0 else -1
                else:
                    eval_time = -1
                epoch_time = train_time + sample_time + eval_time
                total_time = gt.get_times().total

                logger.record_tabular('Train Time (s)', train_time)
                logger.record_tabular('(Previous) Eval Time (s)', eval_time)
                logger.record_tabular('Sample Time (s)', sample_time)
                logger.record_tabular('Epoch Time (s)', epoch_time)
                logger.record_tabular('Total Train Time (s)', total_time)
            else:
                logger.record_tabular('Epoch Time (s)',
                                      time.time() - self._epoch_start_time)
            logger.record_tabular("Epoch", epoch)
            logger.dump_tabular(with_prefix=False, with_timestamp=False)
        else:
            logger.log("Skipping eval for now.")
Example #14
    def pretrain_q_with_bc_data(self):
        logger.remove_tabular_output('progress.csv',
                                     relative_to_snapshot_dir=True)
        logger.add_tabular_output('pretrain_q.csv',
                                  relative_to_snapshot_dir=True)
        self.update_policy = False
        # first train only the Q function
        for i in range(self.q_num_pretrain_steps):
            self.eval_statistics = dict()
            self._need_to_update_eval_statistics = True

            train_data = self.replay_buffer.random_batch(128)
            train_data = np_to_pytorch_batch(train_data)
            obs = train_data['observations']
            next_obs = train_data['next_observations']
            if self.goal_conditioned:
                goals = train_data['resampled_goals']
                train_data['observations'] = torch.cat((obs, goals), dim=1)
                train_data['next_observations'] = torch.cat((next_obs, goals),
                                                            dim=1)
            self.train_from_torch(train_data)

            logger.record_dict(self.eval_statistics)
            logger.dump_tabular(with_prefix=True, with_timestamp=False)

        self.update_policy = True
        # then train policy and Q function together
        for i in range(self.q_num_pretrain_steps):
            self.eval_statistics = dict()
            self._need_to_update_eval_statistics = True

            train_data = self.replay_buffer.random_batch(128)
            train_data = np_to_pytorch_batch(train_data)
            obs = train_data['observations']
            next_obs = train_data['next_observations']
            if self.goal_conditioned:
                goals = train_data['resampled_goals']
                train_data['observations'] = torch.cat((obs, goals), dim=1)
                train_data['next_observations'] = torch.cat((next_obs, goals),
                                                            dim=1)
            self.train_from_torch(train_data)

            logger.record_dict(self.eval_statistics)
            logger.dump_tabular(with_prefix=True, with_timestamp=False)

        logger.remove_tabular_output(
            'pretrain_q.csv',
            relative_to_snapshot_dir=True,
        )
        logger.add_tabular_output(
            'progress.csv',
            relative_to_snapshot_dir=True,
        )
Example #15
    def _try_to_offline_eval(self, epoch):
        start_time = time.time()
        logger.save_extra_data(self.get_extra_data_to_save(epoch))
        self.offline_evaluate(epoch)
        params = self.get_epoch_snapshot(epoch)
        logger.save_itr_params(epoch, params)
        table_keys = logger.get_table_key_set()
        if self._old_table_keys is not None:
            assert table_keys == self._old_table_keys, (
                "Table keys cannot change from iteration to iteration.")
        self._old_table_keys = table_keys
        logger.dump_tabular(with_prefix=False, with_timestamp=False)
        logger.log("Eval Time: {0}".format(time.time() - start_time))
Example #16
def simulate_policy(args):
    if args.pause:
        import ipdb; ipdb.set_trace()
    data = pickle.load(open(args.file, "rb")) # joblib.load(args.file)
    if 'policy' in data:
        policy = data['policy']
    elif 'evaluation/policy' in data:
        policy = data['evaluation/policy']

    if 'env' in data:
        env = data['env']
    elif 'evaluation/env' in data:
        env = data['evaluation/env']

    if isinstance(env, RemoteRolloutEnv):
        env = env._wrapped_env
    print("Policy loaded")
    if args.gpu:
        ptu.set_gpu_mode(True)
        policy.to(ptu.device)
    else:
        ptu.set_gpu_mode(False)
        policy.to(ptu.device)
    if isinstance(env, VAEWrappedEnv):
        env.mode(args.mode)
    if args.enable_render or hasattr(env, 'enable_render'):
        # some environments need to be reconfigured for visualization
        env.enable_render()
    if args.multitaskpause:
        env.pause_on_goal = True
    if isinstance(policy, PyTorchModule):
        policy.train(False)
    paths = []
    while True:
        paths.append(multitask_rollout(
            env,
            policy,
            max_path_length=args.H,
            render=not args.hide,
            observation_key=data.get('evaluation/observation_key', 'observation'),
            desired_goal_key=data.get('evaluation/desired_goal_key', 'desired_goal'),
        ))
        if hasattr(env, "log_diagnostics"):
            env.log_diagnostics(paths)
        if hasattr(env, "get_diagnostics"):
            for k, v in env.get_diagnostics(paths).items():
                logger.record_tabular(k, v)
        logger.dump_tabular()
Example #17
def simulate_policy(args):
    dir = args.path
    data = joblib.load("{}/params.pkl".format(dir))
    env = data['env']
    model_params = data['model_params']
    mpc_params = data['mpc_params']
    # dyn_model = NNDynamicsModel(env=env, **model_params)
    # mpc_controller = MPCcontroller(env=env,
    #                                dyn_model=dyn_model,
    #                                **mpc_params)
    tf_path_meta = "{}/tf_out-0.meta".format(dir)
    tf_path = "{}/tf_out-0".format(dir)

    with tf.Session() as sess:
        new_saver = tf.train.import_meta_graph(tf_path_meta)
        new_saver.restore(sess, tf_path)

    env = data['env']
    if isinstance(env, RemoteRolloutEnv):
        env = env._wrapped_env
    print("Policy loaded")
    if args.gpu:
        set_gpu_mode(True)
        policy.to(ptu.device)
    if args.pause:
        import ipdb
        ipdb.set_trace()
    if isinstance(policy, PyTorchModule):
        policy.train(False)
    while True:
        try:
            path = rollout(
                env,
                policy,
                max_path_length=args.H,
                animated=True,
            )
            env.log_diagnostics([path])
            policy.log_diagnostics([path])
            logger.dump_tabular()
        # Hack for now. Not sure why rollout assumes that close is an
        # keyword argument
        except TypeError as e:
            if (str(e) != "render() got an unexpected keyword "
                    "argument 'close'"):
                raise e
Example #18
def create_policy(variant):
    bottom_snapshot = joblib.load(variant['bottom_path'])
    column_snapshot = joblib.load(variant['column_path'])
    policy = variant['combiner_class'](
        policy1=bottom_snapshot['naf_policy'],
        policy2=column_snapshot['naf_policy'],
    )
    env = bottom_snapshot['env']
    logger.save_itr_params(0, dict(
        policy=policy,
        env=env,
    ))
    path = rollout(
        env,
        policy,
        max_path_length=variant['max_path_length'],
        animated=variant['render'],
    )
    env.log_diagnostics([path])
    logger.dump_tabular()
Example #19
def simulate_policy(args):
    data = joblib.load(args.file)
    qfs = data['qfs']
    env = data['env']
    print("Data loaded")
    if args.pause:
        import ipdb; ipdb.set_trace()
    for qf in qfs:
        qf.train(False)
    paths = []
    while True:
        paths.append(finite_horizon_rollout(
            env,
            qfs,
            max_path_length=args.H,
            max_T=args.mt,
        ))
        if hasattr(env, "log_diagnostics"):
            env.log_diagnostics(paths)
        logger.dump_tabular()
Example #20
def experiment(variant):
    num_rollouts = variant['num_rollouts']
    H = variant['H']
    render = variant['render']
    env = MultitaskPoint2DEnv()
    qf = PerfectPoint2DQF()
    policy = variant['policy_class'](qf, env, **variant['policy_params'])
    paths = []
    for _ in range(num_rollouts):
        goal = env.sample_goal_state_for_rollout()
        path = multitask_rollout(
            env,
            policy,
            goal,
            discount=0,
            max_path_length=H,
            animated=render,
        )
        paths.append(path)
    env.log_diagnostics(paths)
    logger.dump_tabular(with_timestamp=False)
Example #21
    def evaluate(self, epoch):
        """
        Perform evaluation for this algorithm.

        :param epoch: The epoch number.
        :param exploration_paths: List of dicts, each representing a path.
        """
        statistics = OrderedDict()
        train_batch = self.get_batch(training=True)
        statistics.update(self._statistics_from_batch(train_batch, "Train"))
        validation_batch = self.get_batch(training=False)
        statistics.update(
            self._statistics_from_batch(validation_batch, "Validation")
        )

        statistics['QF Loss Validation - Train Gap'] = (
            statistics['Validation QF Loss Mean']
            - statistics['Train QF Loss Mean']
        )
        statistics['Epoch'] = epoch
        for key, value in statistics.items():
            logger.record_tabular(key, value)
        logger.dump_tabular(with_prefix=False, with_timestamp=False)
Example #22
def experiment(variant):
    path = variant['path']
    policy_class = variant['policy_class']
    policy_params = variant['policy_params']
    horizon = variant['horizon']
    num_rollouts = variant['num_rollouts']
    discount = variant['discount']
    stat_name = variant['stat_name']

    data = joblib.load(path)
    env = data['env']
    qf = data['qf']
    qf_argmax_policy = data['policy']
    policy = policy_class(
        qf,
        env,
        qf_argmax_policy,
        **policy_params
    )

    paths = []
    for _ in range(num_rollouts):
        goal = env.sample_goal_for_rollout()
        path = multitask_rollout(
            env,
            policy,
            goal,
            discount,
            max_path_length=horizon,
            animated=False,
            decrement_discount=False,
        )
        paths.append(path)
    env.log_diagnostics(paths)
    results = logger.get_table_dict()
    logger.dump_tabular()
    return results[stat_name]
Example #23
def simulate_policy(args):
    data = joblib.load(args.file)
    policy = data['mpc_controller']
    env = data['env']
    print("Policy loaded")
    if args.pause:
        import ipdb
        ipdb.set_trace()
    policy.cost_fn = env.cost_fn
    policy.env = env
    if args.T:
        policy.mpc_horizon = args.T
    paths = []
    while True:
        paths.append(
            rollout(
                env,
                policy,
                max_path_length=args.H,
                animated=True,
            ))
        if hasattr(env, "log_diagnostics"):
            env.log_diagnostics(paths)
        logger.dump_tabular()
Example #24
def experiment(variant):
    num_rollouts = variant['num_rollouts']
    data = joblib.load(variant['qf_path'])
    qf = data['qf']
    env = data['env']
    qf_policy = data['policy']
    if ptu.gpu_enabled():
        qf.to(ptu.device)
        qf_policy.to(ptu.device)
    if isinstance(qf, VectorizedGoalStructuredUniversalQfunction):
        policy = UnconstrainedOcWithImplicitModel(qf, env, qf_policy,
                                                  **variant['policy_params'])
    else:
        policy = UnconstrainedOcWithGoalConditionedModel(
            qf, env, qf_policy, **variant['policy_params'])
    paths = []
    for _ in range(num_rollouts):
        goal = env.sample_goal_for_rollout()
        print("goal", goal)
        path = multitask_rollout(env, policy, goal,
                                 **variant['rollout_params'])
        paths.append(path)
    env.log_diagnostics(paths)
    logger.dump_tabular(with_timestamp=False)
Example #25
def train_rfeatures_model(variant, return_data=False):
    from railrl.misc.ml_util import PiecewiseLinearSchedule
    # from railrl.torch.vae.conv_vae import (
    #     ConvVAE, ConvResnetVAE
    # )
    import railrl.torch.vae.conv_vae as conv_vae
    # from railrl.torch.vae.vae_trainer import ConvVAETrainer
    from railrl.launchers.experiments.ashvin.rfeatures.rfeatures_model import TimestepPredictionModel
    from railrl.launchers.experiments.ashvin.rfeatures.rfeatures_trainer import TimePredictionTrainer
    from railrl.core import logger
    import railrl.torch.pytorch_util as ptu
    from railrl.pythonplusplus import identity
    import torch
    output_classes = variant["output_classes"]
    representation_size = variant["representation_size"]
    batch_size = variant["batch_size"]
    variant['dataset_kwargs']["output_classes"] = output_classes
    train_dataset, test_dataset, info = get_data(variant['dataset_kwargs'])

    num_train_workers = variant.get("num_train_workers",
                                    0)  # 0 uses main process (good for pdb)
    train_dataset_loader = InfiniteBatchLoader(
        train_dataset,
        batch_size=batch_size,
        num_workers=num_train_workers,
    )
    test_dataset_loader = InfiniteBatchLoader(
        test_dataset,
        batch_size=batch_size,
    )

    logger.save_extra_data(info)
    logger.get_snapshot_dir()
    if variant.get('decoder_activation', None) == 'sigmoid':
        decoder_activation = torch.nn.Sigmoid()
    else:
        decoder_activation = identity
    architecture = variant['model_kwargs'].get('architecture', None)
    if not architecture and variant.get('imsize') == 84:
        architecture = conv_vae.imsize84_default_architecture
    elif not architecture and variant.get('imsize') == 48:
        architecture = conv_vae.imsize48_default_architecture
    variant['model_kwargs']['architecture'] = architecture

    model_class = variant.get('model_class', TimestepPredictionModel)
    model = model_class(
        representation_size,
        decoder_output_activation=decoder_activation,
        output_classes=output_classes,
        **variant['model_kwargs'],
    )
    # model = torch.nn.DataParallel(model)
    model.to(ptu.device)

    variant['trainer_kwargs']['batch_size'] = batch_size
    trainer_class = variant.get('trainer_class', TimePredictionTrainer)
    trainer = trainer_class(
        model,
        **variant['trainer_kwargs'],
    )
    save_period = variant['save_period']

    trainer.dump_trajectory_rewards(
        "initial", dict(train=train_dataset.dataset,
                        test=test_dataset.dataset))

    dump_skew_debug_plots = variant.get('dump_skew_debug_plots', False)
    for epoch in range(variant['num_epochs']):
        should_save_imgs = (epoch % save_period == 0)
        trainer.train_epoch(epoch, train_dataset_loader, batches=10)
        trainer.test_epoch(epoch, test_dataset_loader, batches=1)

        if should_save_imgs:
            trainer.dump_reconstructions(epoch)

        trainer.dump_trajectory_rewards(
            epoch, dict(train=train_dataset.dataset,
                        test=test_dataset.dataset), should_save_imgs)

        stats = trainer.get_diagnostics()
        for k, v in stats.items():
            logger.record_tabular(k, v)
        logger.dump_tabular()
        trainer.end_epoch(epoch)

        if epoch % 50 == 0:
            logger.save_itr_params(epoch, model)
    logger.save_extra_data(model, 'vae.pkl', mode='pickle')
    if return_data:
        return model, train_dataset, test_dataset
    return model
Example #26
def get_n_train_vae(latent_dim,
                    env,
                    vae_train_epochs,
                    num_image_examples,
                    vae_kwargs,
                    vae_trainer_kwargs,
                    vae_architecture,
                    vae_save_period=10,
                    vae_test_p=.9,
                    decoder_activation='sigmoid',
                    vae_class='VAE',
                    **kwargs):
    env.goal_sampling_mode = 'test'
    image_examples = unnormalize_image(
        env.sample_goals(num_image_examples)['desired_goal'])
    n = int(num_image_examples * vae_test_p)
    train_dataset = ImageObservationDataset(image_examples[:n, :])
    test_dataset = ImageObservationDataset(image_examples[n:, :])

    if decoder_activation == 'sigmoid':
        decoder_activation = torch.nn.Sigmoid()

    vae_class = vae_class.lower()
    if vae_class == 'VAE'.lower():
        vae_class = ConvVAE
    elif vae_class == 'SpatialVAE'.lower():
        vae_class = SpatialAutoEncoder
    else:
        raise RuntimeError("Invalid VAE Class: {}".format(vae_class))

    vae = vae_class(latent_dim,
                    architecture=vae_architecture,
                    decoder_output_activation=decoder_activation,
                    **vae_kwargs)

    trainer = ConvVAETrainer(vae, **vae_trainer_kwargs)

    logger.remove_tabular_output('progress.csv', relative_to_snapshot_dir=True)
    logger.add_tabular_output('vae_progress.csv',
                              relative_to_snapshot_dir=True)
    for epoch in range(vae_train_epochs):
        should_save_imgs = (epoch % vae_save_period == 0)
        trainer.train_epoch(epoch, train_dataset)
        trainer.test_epoch(epoch, test_dataset)

        if should_save_imgs:
            trainer.dump_reconstructions(epoch)
            trainer.dump_samples(epoch)
        stats = trainer.get_diagnostics()
        for k, v in stats.items():
            logger.record_tabular(k, v)

        logger.dump_tabular()
        trainer.end_epoch(epoch)

        if epoch % 50 == 0:
            logger.save_itr_params(epoch, vae)
    logger.save_extra_data(vae, 'vae.pkl', mode='pickle')
    logger.remove_tabular_output('vae_progress.csv',
                                 relative_to_snapshot_dir=True)
    logger.add_tabular_output('progress.csv', relative_to_snapshot_dir=True)
    return vae
Example #27
        goal = env.sample_goal_for_rollout()
        goal[7:14] = 0
        path = multitask_rollout(
            env,
            original_policy,
            # env.multitask_goal,
            goal,
            init_tau=10,
            max_path_length=args.H,
            animated=not args.hide,
            cycle_tau=True,
            decrement_tau=False,
        )
        if hasattr(env, "log_diagnostics"):
            env.log_diagnostics([path])
        logger.dump_tabular()
else:
    for weight in [1]:
        for num_simulated_paths in [args.npath]:
            print("")
            print("weight", weight)
            print("num_simulated_paths", num_simulated_paths)
            policy = CollocationMpcController(
                env,
                implicit_model,
                original_policy,
                num_simulated_paths=num_simulated_paths,
                feasibility_weight=weight,
            )
            policy.train(False)
            paths = []
Example #28
def train(dataset_generator,
          n_start_samples,
          projection=project_samples_square_np,
          n_samples_to_add_per_epoch=1000,
          n_epochs=100,
          save_period=10,
          append_all_data=True,
          full_variant=None,
          dynamics_noise=0,
          num_bins=5,
          weight_type='sqrt_inv_p',
          **kwargs):
    report = HTMLReport(
        logger.get_snapshot_dir() + '/report.html',
        images_per_row=3,
    )
    dynamics = Dynamics(projection, dynamics_noise)
    if full_variant:
        report.add_header("Variant")
        report.add_text(
            json.dumps(
                ppp.dict_to_safe_json(full_variant, sort=True),
                indent=2,
            ))

    orig_train_data = dataset_generator(n_start_samples)
    train_data = orig_train_data

    heatmap_imgs = []
    sample_imgs = []
    entropies = []
    tvs_to_uniform = []
    """
    p_theta = previous iteration's model
    p_new = this iteration's distribution
    """
    p_theta = Histogram(num_bins, weight_type=weight_type)
    for epoch in range(n_epochs):
        logger.record_tabular('Epoch', epoch)
        logger.record_tabular('Entropy ', p_theta.entropy())
        logger.record_tabular('KL from uniform', p_theta.kl_from_uniform())
        logger.record_tabular('TV to uniform', p_theta.tv_to_uniform())
        entropies.append(p_theta.entropy())
        tvs_to_uniform.append(p_theta.tv_to_uniform())

        samples = p_theta.sample(n_samples_to_add_per_epoch)
        empirical_samples = dynamics(samples)

        if append_all_data:
            train_data = np.vstack((train_data, empirical_samples))
        else:
            train_data = np.vstack((orig_train_data, empirical_samples))

        if epoch == 0 or (epoch + 1) % save_period == 0:
            report.add_text("Epoch {}".format(epoch))
            heatmap_img = visualize_histogram(epoch, p_theta, report)
            sample_img = visualize_samples(epoch, train_data, p_theta, report,
                                           dynamics)
            heatmap_imgs.append(heatmap_img)
            sample_imgs.append(sample_img)
            report.save()

            from PIL import Image
            Image.fromarray(heatmap_img).save(logger.get_snapshot_dir() +
                                              '/heatmap{}.png'.format(epoch))
            Image.fromarray(sample_img).save(logger.get_snapshot_dir() +
                                             '/samples{}.png'.format(epoch))
        weights = p_theta.compute_per_elem_weights(train_data)
        p_new = Histogram(num_bins, weight_type=weight_type)
        p_new.fit(
            train_data,
            weights=weights,
        )
        p_theta = p_new
        logger.dump_tabular()
    plot_curves([
        ("Entropy", entropies),
        ("TVs to Uniform", tvs_to_uniform),
    ], report)
    report.add_text("Max entropy: {}".format(p_theta.max_entropy()))
    report.save()

    heatmap_video = np.stack(heatmap_imgs)
    sample_video = np.stack(sample_imgs)

    vwrite(
        logger.get_snapshot_dir() + '/heatmaps.mp4',
        heatmap_video,
    )
    vwrite(
        logger.get_snapshot_dir() + '/samples.mp4',
        sample_video,
    )
    try:
        gif(
            logger.get_snapshot_dir() + '/samples.gif',
            sample_video,
        )
        gif(
            logger.get_snapshot_dir() + '/heatmaps.gif',
            heatmap_video,
        )
        report.add_image(
            logger.get_snapshot_dir() + '/samples.gif',
            "Samples GIF",
            is_url=True,
        )
        report.add_image(
            logger.get_snapshot_dir() + '/heatmaps.gif',
            "Heatmaps GIF",
            is_url=True,
        )
        report.save()
    except ImportError as e:
        print(e)
Example #29
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('file', type=str, help='path to the snapshot file')
    parser.add_argument('--H',
                        type=int,
                        default=300,
                        help='Max length of rollout')
    parser.add_argument('--nrolls',
                        type=int,
                        default=1,
                        help='Number of rollout per eval')
    parser.add_argument('--verbose', action='store_true')
    parser.add_argument('--mtau', type=float, help='Max tau value')
    parser.add_argument('--grid', action='store_true')
    parser.add_argument('--gpu', action='store_true')
    parser.add_argument('--load', action='store_true')
    parser.add_argument('--hide', action='store_true')
    parser.add_argument('--pause', action='store_true')
    parser.add_argument('--cycle', help='cycle tau', action='store_true')
    args = parser.parse_args()

    data = joblib.load(args.file)
    env = data['env']
    if 'policy' in data:
        policy = data['policy']
    else:
        policy = data['exploration_policy']
    qf = data['qf']
    policy.train(False)
    qf.train(False)

    if args.pause:
        import ipdb
        ipdb.set_trace()

    if args.gpu:
        ptu.set_gpu_mode(True)
        policy.to(ptu.device)

    if args.mtau is None:
        print("Defaulting max tau to 10.")
        max_tau = 10
    else:
        max_tau = args.mtau

    while True:
        paths = []
        for _ in range(args.nrolls):
            goal = env.sample_goal_for_rollout()
            print("goal", goal)
            env.set_goal(goal)
            policy.set_goal(goal)
            policy.set_tau(max_tau)
            path = rollout(
                env,
                policy,
                qf,
                init_tau=max_tau,
                max_path_length=args.H,
                animated=not args.hide,
                cycle_tau=args.cycle,
            )
            paths.append(path)
        env.log_diagnostics(paths)
        for key, value in get_generic_path_information(paths).items():
            logger.record_tabular(key, value)
        logger.dump_tabular()
Example #30
def simulate_policy(args):
    data = joblib.load(args.file)
    if 'eval_policy' in data:
        policy = data['eval_policy']
    elif 'policy' in data:
        policy = data['policy']
    elif 'exploration_policy' in data:
        policy = data['exploration_policy']
    else:
        raise Exception("No policy found in loaded dict. Keys: {}".format(
            data.keys()))

    env = data['env']

    env.mode("video_env")
    env.decode_goals = True

    if hasattr(env, 'enable_render'):
        # some environments need to be reconfigured for visualization
        env.enable_render()

    if args.gpu:
        set_gpu_mode(True)
        policy.to(ptu.device)
        if hasattr(env, "vae"):
            env.vae.to(ptu.device)
    else:
        # make sure everything is on the CPU
        set_gpu_mode(False)
        policy.cpu()
        if hasattr(env, "vae"):
            env.vae.cpu()

    if args.pause:
        import ipdb
        ipdb.set_trace()
    if isinstance(policy, PyTorchModule):
        policy.train(False)
    ROWS = 3
    COLUMNS = 6
    dirname = osp.dirname(args.file)
    input_file_name = os.path.splitext(os.path.basename(args.file))[0]
    filename = osp.join(dirname, "video_{}.mp4".format(input_file_name))
    rollout_function = create_rollout_function(
        multitask_rollout,
        observation_key='observation',
        desired_goal_key='desired_goal',
    )
    paths = dump_video(
        env,
        policy,
        filename,
        rollout_function,
        ROWS=ROWS,
        COLUMNS=COLUMNS,
        horizon=args.H,
        dirname_to_save_images=dirname,
        subdirname="rollouts_" + input_file_name,
    )

    if hasattr(env, "log_diagnostics"):
        env.log_diagnostics(paths)
    logger.dump_tabular()