Example #1
    def _try_to_eval(self, epoch, eval_paths=None):
        logger.save_extra_data(self.get_extra_data_to_save(epoch))
        if self._can_evaluate():
            self.evaluate(epoch, eval_paths=eval_paths)

            params = self.get_epoch_snapshot(epoch)
            logger.save_itr_params(epoch, params)
            table_keys = logger.get_table_key_set()
            if self._old_table_keys is not None:
                assert table_keys == self._old_table_keys, (
                    "Table keys cannot change from iteration to iteration.")
            self._old_table_keys = table_keys

            logger.record_tabular(
                "Number of train steps total",
                self._n_train_steps_total,
            )
            logger.record_tabular(
                "Number of env steps total",
                self._n_env_steps_total,
            )
            logger.record_tabular(
                "Number of rollouts total",
                self._n_rollouts_total,
            )

            times_itrs = gt.get_times().stamps.itrs
            train_time = times_itrs['train'][-1]
            sample_time = times_itrs['sample'][-1]
            eval_time = times_itrs['eval'][-1] if epoch > 0 else 0
            epoch_time = train_time + sample_time + eval_time
            total_time = gt.get_times().total

            logger.record_tabular('Train Time (s)', train_time)
            logger.record_tabular('(Previous) Eval Time (s)', eval_time)
            logger.record_tabular('Sample Time (s)', sample_time)
            logger.record_tabular('Epoch Time (s)', epoch_time)
            logger.record_tabular('Total Train Time (s)', total_time)
            logger.record_tabular("Epoch", epoch)
            logger.dump_tabular(with_prefix=False, with_timestamp=False)
        else:
            logger.log("Skipping eval for now.")
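The 'train'/'sample'/'eval' entries read from gt.get_times().stamps.itrs come from gtimer stamps placed inside the training loop. A minimal, self-contained sketch of that stamping pattern (the fake_work helper and the loop body are illustrative stand-ins, not from the original source):

import time

import gtimer as gt

def fake_work(seconds):
    time.sleep(seconds)

# gt.timed_for(..., save_itrs=True) records per-iteration times for each stamp
for epoch in gt.timed_for(range(3), save_itrs=True):
    fake_work(0.01)   # stands in for environment sampling
    gt.stamp('sample')
    fake_work(0.02)   # stands in for gradient updates
    gt.stamp('train')
    fake_work(0.01)   # stands in for evaluation rollouts
    gt.stamp('eval')

print(gt.get_times().stamps.itrs['train'])  # one entry per iteration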
Example #2
def simulate_policy(args):
    data = torch.load(str(args.file))
    #data = joblib.load(str(args.file))
    policy = data['evaluation/policy']
    env = NormalizedBoxEnv(HalfCheetahEnv())
    #env = data['evaluation/env']
    print("Policy loaded")
    if args.gpu:
        set_gpu_mode(True)
        policy.cuda()
    while True:
        path = rollout(
            env,
            policy,
            max_path_length=args.H,
            render=True,
        )
        if hasattr(env, "log_diagnostics"):
            env.log_diagnostics([path])
        logger.dump_tabular()
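A minimal command-line driver for this function, modeled on rlkit's run_policy script (the exact flag names and defaults here are assumptions):

import argparse

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument('file', type=str, help='path to the snapshot file')
    parser.add_argument('--H', type=int, default=300, help='max rollout length')
    parser.add_argument('--gpu', action='store_true')
    args = parser.parse_args()
    simulate_policy(args)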
Example #3
def create_policy(variant):
    bottom_snapshot = joblib.load(variant['bottom_path'])
    column_snapshot = joblib.load(variant['column_path'])
    policy = variant['combiner_class'](
        policy1=bottom_snapshot['naf_policy'],
        policy2=column_snapshot['naf_policy'],
    )
    env = bottom_snapshot['env']
    logger.save_itr_params(0, dict(
        policy=policy,
        env=env,
    ))
    path = rollout(
        env,
        policy,
        max_path_length=variant['max_path_length'],
        animated=variant['render'],
    )
    env.log_diagnostics([path])
    logger.dump_tabular()
Example #4
def simulate_policy(args):
    data = joblib.load(args.file)  # joblib uses pickle internally
    policy = data['policy']
    env = data['env']
    print("Policy loaded")
    if args.gpu:
        set_gpu_mode(True)
        policy.cuda()
    if isinstance(policy, PyTorchModule):
        policy.train(False)
    while True:
        path = rollout(
            env,
            policy,
            max_path_length=args.H,
            animated=False,
        )
        if hasattr(env, "log_diagnostics"):
            env.log_diagnostics([path])
        logger.dump_tabular()
Example #5
def simulate_policy(args):
    data = torch.load(args.file)
    policy = data['evaluation/policy']
    env = data['evaluation/env']
    print("Policy loaded")
    if args.gpu:
        set_gpu_mode(True)
        policy.cuda()
    while True:
        path = multitask_rollout(
            env,
            policy,
            max_path_length=args.H,
            render=True,
            observation_key='observation',
            desired_goal_key='desired_goal',
        )
        if hasattr(env, "log_diagnostics"):
            env.log_diagnostics([path])
        logger.dump_tabular()
Example #6
    def pretrain_q_with_bc_data(self, batch_size):
        logger.remove_tabular_output('progress.csv',
                                     relative_to_snapshot_dir=True)
        logger.add_tabular_output('pretrain_q.csv',
                                  relative_to_snapshot_dir=True)

        prev_time = time.time()
        for i in range(self.num_pretrain_steps):
            self.eval_statistics = dict()
            if i % self.pretraining_logging_period == 0:
                self._need_to_update_eval_statistics = True
            train_data = self.replay_buffer.random_batch(self.bc_batch_size)
            train_data = np_to_pytorch_batch(train_data)
            obs = train_data['observations']
            next_obs = train_data['next_observations']
            # goals = train_data['resampled_goals']
            train_data['observations'] = obs  # torch.cat((obs, goals), dim=1)
            train_data[
                'next_observations'] = next_obs  # torch.cat((next_obs, goals), dim=1)
            self.train_from_torch(train_data, pretrain=True)

            if i % self.pretraining_logging_period == 0:
                self.eval_statistics["batch"] = i
                self.eval_statistics["epoch_time"] = time.time() - prev_time
                stats_with_prefix = add_prefix(self.eval_statistics,
                                               prefix="trainer/")
                logger.record_dict(stats_with_prefix)
                logger.dump_tabular(with_prefix=True, with_timestamp=False)
                prev_time = time.time()

        logger.remove_tabular_output(
            'pretrain_q.csv',
            relative_to_snapshot_dir=True,
        )
        logger.add_tabular_output(
            'progress.csv',
            relative_to_snapshot_dir=True,
        )

        self._need_to_update_eval_statistics = True
        self.eval_statistics = dict()
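The remove/add calls around 'progress.csv' and 'pretrain_q.csv' temporarily redirect the logger's tabular output for the pretraining phase and then restore it. The same pattern as a context manager (tabular_output is a hypothetical helper, not part of rlkit):

from contextlib import contextmanager

from rlkit.core import logger

@contextmanager
def tabular_output(filename, default='progress.csv'):
    # route tabular logging to `filename`, restoring `default` on exit
    logger.remove_tabular_output(default, relative_to_snapshot_dir=True)
    logger.add_tabular_output(filename, relative_to_snapshot_dir=True)
    try:
        yield
    finally:
        logger.remove_tabular_output(filename, relative_to_snapshot_dir=True)
        logger.add_tabular_output(default, relative_to_snapshot_dir=True)

# usage: with tabular_output('pretrain_q.csv'): ...pretraining loop...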
Example #7
def simulate_policy(fpath,
                    env_name,
                    seed,
                    max_path_length,
                    num_eval_steps,
                    headless,
                    max_eps,
                    verbose=True,
                    pause=False):
    data = torch.load(fpath, map_location=ptu.device)
    policy = data['evaluation/policy']
    policy.to(ptu.device)
    # make a new env; reloading data['evaluation/env'] seems to introduce a bug
    env = gym.make(env_name, **{"headless": headless, "verbose": False})
    env.seed(seed)
    if pause:
        input("Waiting to start.")
    path_collector = MdpPathCollector(env, policy)
    paths = path_collector.collect_new_paths(
        max_path_length,
        num_eval_steps,
        discard_incomplete_paths=True,
    )

    if max_eps:
        paths = paths[:max_eps]
    if verbose:
        completions = sum([
            info["completed"] for path in paths for info in path["env_infos"]
        ])
        print("Completed {} out of {}".format(completions, len(paths)))
        # plt.plot(paths[0]["actions"])
        # plt.show()
        # plt.plot(paths[2]["observations"])
        # plt.show()
        logger.record_dict(
            eval_util.get_generic_path_information(paths),
            prefix="evaluation/",
        )
        logger.dump_tabular()
    return paths
Example #8
    def evaluate(self):
        eval_statistics = OrderedDict()
        self.mlp.eval()
        self.encoder.eval()
        # for i in range(self.min_context_size, self.max_context_size+1):
        for i in range(1, 12):
            # prep the batches
            context_batch, mask, obs_task_params, classification_inputs, classification_labels = self._get_eval_batch(
                self.num_tasks_per_eval, i)
            # print(len(context_batch))
            # print(len(context_batch[0]))

            post_dist = self.encoder(context_batch, mask)
            z = post_dist.sample()  # N_tasks x Dim
            # z = post_dist.mean

            obs_task_params = Variable(ptu.from_numpy(obs_task_params))
            # print(obs_task_params)

            if self.training_regression:
                preds = self.mlp(z)
                loss = self.mse(preds, obs_task_params)
                eval_statistics['Loss for %d' % i] = np.mean(
                    ptu.get_numpy(loss))
            else:
                repeated_z = z.repeat(
                    1, self.classification_batch_size_per_task).view(
                        -1, z.size(1))
                mlp_input = torch.cat([classification_inputs, repeated_z],
                                      dim=-1)
                preds = self.mlp(mlp_input)
                # loss = self.bce(preds, classification_labels)
                class_preds = (preds > 0).type(preds.data.type())
                accuracy = (class_preds == classification_labels).type(
                    torch.FloatTensor).mean()
                eval_statistics['Acc for %d' % i] = np.mean(
                    ptu.get_numpy(accuracy))

        for key, value in eval_statistics.items():
            logger.record_tabular(key, value)
        logger.dump_tabular(with_prefix=False, with_timestamp=False)
Example #9
    def _try_to_eval(self, epoch):

        if self._can_evaluate():
            # save if it's time to save
            if epoch % self.freq_saving == 0:
                logger.save_extra_data(self.get_extra_data_to_save(epoch))
                params = self.get_epoch_snapshot(epoch)
                logger.save_itr_params(epoch, params)

            self.evaluate(epoch)

            logger.record_tabular(
                "Number of train calls total",
                self._n_train_steps_total,
            )
            logger.record_tabular(
                "Number of env steps total",
                self._n_env_steps_total,
            )
            logger.record_tabular(
                "Number of rollouts total",
                self._n_rollouts_total,
            )

            times_itrs = gt.get_times().stamps.itrs
            train_time = times_itrs['train'][-1]
            sample_time = times_itrs['sample'][-1]
            eval_time = times_itrs['eval'][-1] if epoch > 0 else 0
            epoch_time = train_time + sample_time + eval_time
            total_time = gt.get_times().total

            logger.record_tabular('Train Time (s)', train_time)
            logger.record_tabular('(Previous) Eval Time (s)', eval_time)
            logger.record_tabular('Sample Time (s)', sample_time)
            logger.record_tabular('Epoch Time (s)', epoch_time)
            logger.record_tabular('Total Train Time (s)', total_time)

            logger.record_tabular("Epoch", epoch)
            logger.dump_tabular(with_prefix=False, with_timestamp=False)
        else:
            logger.log("Skipping eval for now.")
Example #10
def simulate_policy(args):
    data = joblib.load(args.file)
    policy = data['exploration_policy']  # ? TODO, eval ?
    env = data['env']
    print("Policy loaded")
    if args.gpu:
        set_gpu_mode("gpu")
        policy.cuda()
    if isinstance(policy, PyTorchModule):
        policy.train(False)
    while True:
        path = rollout(
            env,
            policy,
            max_path_length=args.H,
            animated=True,
        )
        if hasattr(env, "log_diagnostics"):
            env.log_diagnostics([path])
        logger.dump_tabular()
Example #11
    def test_epoch(
            self,
            epoch,
    ):
        self.model.eval()
        val_losses = []
        per_dim_losses = np.zeros((self.num_batches, self.y_train.shape[1]))
        for batch in range(self.num_batches):
            inputs_np, labels_np = self.random_batch(self.X_test, self.y_test, batch_size=self.batch_size)
            inputs, labels = ptu.Variable(ptu.from_numpy(inputs_np)), ptu.Variable(ptu.from_numpy(labels_np))
            outputs = self.model(inputs)
            loss = self.criterion(outputs, labels)
            val_losses.append(loss.item())  # loss.data[0] is pre-0.4 PyTorch
            per_dim_loss = np.mean(np.power(ptu.get_numpy(outputs - labels), 2), axis=0)
            per_dim_losses[batch] = per_dim_loss

        logger.record_tabular("test/epoch", epoch)
        logger.record_tabular("test/loss", np.mean(np.array(val_losses)))
        for i in range(self.y_train.shape[1]):
            logger.record_tabular("test/dim "+str(i)+" loss", np.mean(per_dim_losses[:, i]))
        logger.dump_tabular()
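For reference, a plausible implementation of the random_batch helper used above (hypothetical; the actual method on this class may differ):

import numpy as np

def random_batch(X, y, batch_size=32):
    # sample a minibatch of row-aligned inputs and labels
    idx = np.random.randint(0, X.shape[0], size=batch_size)
    return X[idx], y[idx]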
Example #12
def simulate_policy(args):
    data = torch.load(args.file)
    policy = data['evaluation/policy']
    env = data['evaluation/env']
    print("Policy loaded")
    if args.gpu:
        set_gpu_mode(True)
        policy.cuda()
    num_fail = 0
    for _ in range(args.ep):
        path = rollout(
            env,
            policy,
            max_path_length=args.H,
            render=False,
            sleep=args.S,
        )
        if np.any(path['rewards'] == -1):
            num_fail += 1
            if args.de:
                last_obs = np.moveaxis(
                    np.reshape(path['observations'][-1], (3, 33, 33)), 0, -1)
                last_next_obs = np.moveaxis(
                    np.reshape(path['next_observations'][-1], (3, 33, 33)), 0,
                    -1)
                last_obs = (last_obs * 33 + 128).astype(np.uint8)
                last_next_obs = (last_next_obs * 33 + 128).astype(np.uint8)
                fig = plt.figure(figsize=(10, 10))
                fig.add_subplot(2, 1, 1)
                plt.imshow(last_obs)
                fig.add_subplot(2, 1, 2)
                plt.imshow(last_next_obs)
                plt.show()
                plt.close()

        if hasattr(env, "log_diagnostics"):
            env.log_diagnostics([path])
        logger.dump_tabular()

    print('number of failures:', num_fail)
Example #13
File: sim_fh_dqn.py Project: jcoreyes/erl
def simulate_policy(args):
    data = joblib.load(args.file)
    qfs = data['qfs']
    env = data['env']
    print("Data loaded")
    if args.pause:
        import ipdb
        ipdb.set_trace()
    for qf in qfs:
        qf.train(False)
    paths = []
    while True:
        paths.append(
            finite_horizon_rollout(
                env,
                qfs,
                max_path_length=args.H,
                max_T=args.mt,
            ))
        if hasattr(env, "log_diagnostics"):
            env.log_diagnostics(paths)
        logger.dump_tabular()
Example #14
def simulate_policy(args):
    data = joblib.load(args.file)
    policy = data['policy']
    # env = data['env']
    from rlkit.envs.mujoco_manip_env import MujocoManipEnv
    env = MujocoManipEnv("SawyerLiftEnv", render=True)
    print("Policy loaded")
    if args.gpu:
        set_gpu_mode(True)
        policy.cuda()
    if isinstance(policy, PyTorchModule):
        policy.train(False)
    while True:
        path = rollout(
            env,
            policy,
            max_path_length=args.H,
            animated=True,
        )
        if hasattr(env, "log_diagnostics"):
            env.log_diagnostics([path])
        logger.dump_tabular()
Example #15
    def evaluate(self, epoch):
        """
        Perform evaluation for this algorithm.

        :param epoch: The epoch number.
        """
        statistics = OrderedDict()
        train_batch = self.get_batch(training=True)
        statistics.update(self._statistics_from_batch(train_batch, "Train"))
        validation_batch = self.get_batch(training=False)
        statistics.update(
            self._statistics_from_batch(validation_batch, "Validation")
        )

        statistics['QF Loss Validation - Train Gap'] = (
            statistics['Validation QF Loss Mean']
            - statistics['Train QF Loss Mean']
        )
        statistics['Epoch'] = epoch
        for key, value in statistics.items():
            logger.record_tabular(key, value)
        logger.dump_tabular(with_prefix=False, with_timestamp=False)
Example #16
def simulate_policy(args):
    manager_data = torch.load(args.manager_file)
    worker_data = torch.load(args.worker_file)
    policy = manager_data['evaluation/policy']
    worker = worker_data['evaluation/policy']
    env = NormalizedBoxEnv(gym.make(str(args.env)))
    print("Policy loaded")
    if args.gpu:
        set_gpu_mode(True)
        policy.cuda()

    import cv2
    video = cv2.VideoWriter('ppo_dirichlet_diayn_bipedal_walker_hardcore.avi',
                            cv2.VideoWriter_fourcc('M', 'J', 'P', 'G'), 30,
                            (1200, 800))
    index = 0

    path = rollout(
        env,
        policy,
        worker,
        continuous=True,
        max_path_length=args.H,
        render=True,
    )
    if hasattr(env, "log_diagnostics"):
        env.log_diagnostics([path])
    logger.dump_tabular()

    for i, img in enumerate(path['images']):
        print(i)
        video.write(img[:, :, ::-1].astype(np.uint8))
        #        cv2.imwrite("frames/ppo_dirichlet_diayn_policy_bipedal_walker_hardcore/%06d.png" % index, img[:,:,::-1])
        index += 1

    video.release()
    print("wrote video")
Example #17
def simulate_policy(args):
    if args.pause:
        import ipdb
        ipdb.set_trace()
    data = pickle.load(open(args.file, "rb"))
    policy = data['policy']
    env = data['env']
    print("Policy and environment loaded")
    if args.gpu:
        ptu.set_gpu_mode(True)
        policy.to(ptu.device)
    # if isinstance(env, VAEWrappedEnv):
    #     env.mode(args.mode)
    if args.enable_render or hasattr(env, 'enable_render'):
        # some environments need to be reconfigured for visualization
        env.enable_render()
    policy.train(False)
    paths = []
    for i in range(1000):
        paths.append(
            multitask_rollout(
                env,
                policy,
                max_path_length=args.H,
                animated=not args.hide,
                observation_key='observation',
                desired_goal_key='context',
            ))
        if hasattr(env, "log_diagnostics"):
            env.log_diagnostics(paths)
        if hasattr(env, "get_diagnostics"):
            for k, v in env.get_diagnostics(paths).items():
                logger.record_tabular(k, v)
        logger.dump_tabular()

    if 'point2d' in type(env.wrapped_env).__name__.lower():
        point2d(paths, args)
Example #18
def simulate_policy(args):
    data = joblib.load(args.file)
    policy = data['mpc_controller']
    env = data['env']
    print("Policy loaded")
    if args.pause:
        import ipdb
        ipdb.set_trace()
    policy.cost_fn = env.cost_fn
    policy.env = env
    if args.T:
        policy.mpc_horizon = args.T
    paths = []
    while True:
        paths.append(
            rollout(
                env,
                policy,
                max_path_length=args.H,
                animated=True,
            ))
        if hasattr(env, "log_diagnostics"):
            env.log_diagnostics(paths)
        logger.dump_tabular()
Example #19
def simulate_policy(args):
    data = torch.load(args.file)
    print(data)
    # policy = data['evaluation/policy']
    # For some reason the snapshot does not include an evaluation policy,
    # so the trainer/policy entry is used instead.
    policy = data['trainer/policy']
    env = data['evaluation/env']
    print("Policy loaded")
    if args.gpu:
        set_gpu_mode(True)
        policy.cuda()
    while True:
        path = rollout(
            env,
            policy,
            max_path_length=args.H,
            render=True,
        )
        if hasattr(env, "log_diagnostics"):
            env.log_diagnostics([path])
        logger.dump_tabular()
Example #20
def simulate_policy(args):
    data = joblib.load(args.file)
    policy = data['policy']
    env = data['env']
    print("Policy loaded")

    farmer = Farmer([('0.0.0.0', 1)])
    env_to_sim = farmer.force_acq_env()

    if args.gpu:
        set_gpu_mode(True)
        policy.cuda()
    if isinstance(policy, PyTorchModule):
        policy.train(False)
    while True:
        path = rollout(
            env_to_sim,
            policy,
            max_path_length=args.H,
            animated=False,
        )
        if hasattr(env, "log_diagnostics"):
            env.log_diagnostics([path])
        logger.dump_tabular()
Example #21
def example(variant):
    import torch

    import rlkit.torch.pytorch_util as ptu

    print("Starting")
    logger.log(torch.__version__)
    date_format = "%m/%d/%Y %H:%M:%S %Z"
    date = datetime.now(tz=pytz.utc)
    logger.log("start")
    logger.log("Current date & time is: {}".format(date.strftime(date_format)))
    logger.log("Cuda available: {}".format(torch.cuda.is_available()))
    if torch.cuda.is_available():
        x = torch.randn(3)
        logger.log(str(x.to(ptu.device)))

    date = date.astimezone(timezone("US/Pacific"))
    logger.log("Local date & time is: {}".format(date.strftime(date_format)))
    for i in range(variant["num_seconds"]):
        logger.log("Tick, {}".format(i))
        time.sleep(1)
    logger.log("end")
    logger.log("Local date & time is: {}".format(date.strftime(date_format)))

    logger.log("start mujoco")
    from gym.envs.mujoco import HalfCheetahEnv

    e = HalfCheetahEnv()
    img = e.sim.render(32, 32)
    logger.log(str(sum(img)))
    logger.log("end mujoco")

    logger.record_tabular("Epoch", 1)
    logger.dump_tabular()
    logger.record_tabular("Epoch", 2)
    logger.dump_tabular()
    logger.record_tabular("Epoch", 3)
    logger.dump_tabular()
    print("Done")
Example #22
    def pretrain_q_with_bc_data(self):
        """

        :return:
        """
        logger.remove_tabular_output('progress.csv',
                                     relative_to_snapshot_dir=True)
        logger.add_tabular_output('pretrain_q.csv',
                                  relative_to_snapshot_dir=True)

        self.update_policy = False
        # first train only the Q function
        for i in range(self.q_num_pretrain1_steps):
            self.eval_statistics = dict()

            train_data = self.replay_buffer.random_batch(self.bc_batch_size)
            train_data = np_to_pytorch_batch(train_data)
            obs = train_data['observations']
            next_obs = train_data['next_observations']
            # goals = train_data['resampled_goals']
            train_data['observations'] = obs  # torch.cat((obs, goals), dim=1)
            train_data[
                'next_observations'] = next_obs  # torch.cat((next_obs, goals), dim=1)
            self.train_from_torch(train_data, pretrain=True)
            if i % self.pretraining_logging_period == 0:
                stats_with_prefix = add_prefix(self.eval_statistics,
                                               prefix="trainer/")
                logger.record_dict(stats_with_prefix)
                logger.dump_tabular(with_prefix=True, with_timestamp=False)

        self.update_policy = True
        # then train policy and Q function together
        prev_time = time.time()
        for i in range(self.q_num_pretrain2_steps):
            self.eval_statistics = dict()
            if i % self.pretraining_logging_period == 0:
                self._need_to_update_eval_statistics = True
            train_data = self.replay_buffer.random_batch(self.bc_batch_size)
            train_data = np_to_pytorch_batch(train_data)
            obs = train_data['observations']
            next_obs = train_data['next_observations']
            # goals = train_data['resampled_goals']
            train_data['observations'] = obs  # torch.cat((obs, goals), dim=1)
            train_data[
                'next_observations'] = next_obs  # torch.cat((next_obs, goals), dim=1)
            self.train_from_torch(train_data, pretrain=True)

            if i % self.pretraining_logging_period == 0:
                self.eval_statistics["batch"] = i
                self.eval_statistics["epoch_time"] = time.time() - prev_time
                stats_with_prefix = add_prefix(self.eval_statistics,
                                               prefix="trainer/")
                logger.record_dict(stats_with_prefix)
                logger.dump_tabular(with_prefix=True, with_timestamp=False)
                prev_time = time.time()

        logger.remove_tabular_output(
            'pretrain_q.csv',
            relative_to_snapshot_dir=True,
        )
        logger.add_tabular_output(
            'progress.csv',
            relative_to_snapshot_dir=True,
        )

        self._need_to_update_eval_statistics = True
        self.eval_statistics = dict()

        if self.post_pretrain_hyperparams:
            self.set_algorithm_weights(**self.post_pretrain_hyperparams)
Example #23
    def pretrain_policy_with_bc(
        self,
        policy,
        train_buffer,
        test_buffer,
        steps,
        label="policy",
    ):
        """Given a policy, first get its optimizer, then run the policy on the train buffer, get the
        losses, and back propagate the loss. After training on a batch, test on the test buffer and
        get the statistics

        :param policy:
        :param train_buffer:
        :param test_buffer:
        :param steps:
        :param label:
        :return:
        """
        logger.remove_tabular_output(
            'progress.csv',
            relative_to_snapshot_dir=True,
        )
        logger.add_tabular_output(
            'pretrain_%s.csv' % label,
            relative_to_snapshot_dir=True,
        )

        optimizer = self.optimizers[policy]
        prev_time = time.time()
        for i in range(steps):
            train_policy_loss, train_logp_loss, train_mse_loss, train_stats = self.run_bc_batch(
                train_buffer, policy)
            train_policy_loss = train_policy_loss * self.bc_weight

            optimizer.zero_grad()
            train_policy_loss.backward()
            optimizer.step()

            test_policy_loss, test_logp_loss, test_mse_loss, test_stats = self.run_bc_batch(
                test_buffer, policy)
            test_policy_loss = test_policy_loss * self.bc_weight

            if i % self.pretraining_logging_period == 0:
                stats = {
                    "pretrain_bc/batch": i,
                    "pretrain_bc/Train Logprob Loss": ptu.get_numpy(train_logp_loss),
                    "pretrain_bc/Test Logprob Loss": ptu.get_numpy(test_logp_loss),
                    "pretrain_bc/Train MSE": ptu.get_numpy(train_mse_loss),
                    "pretrain_bc/Test MSE": ptu.get_numpy(test_mse_loss),
                    "pretrain_bc/train_policy_loss": ptu.get_numpy(train_policy_loss),
                    "pretrain_bc/test_policy_loss": ptu.get_numpy(test_policy_loss),
                    "pretrain_bc/epoch_time": time.time() - prev_time,
                }

                logger.record_dict(stats)
                logger.dump_tabular(with_prefix=True, with_timestamp=False)
                pickle.dump(
                    self.policy,
                    open(logger.get_snapshot_dir() + '/bc_%s.pkl' % label,
                         "wb"))
                prev_time = time.time()

        logger.remove_tabular_output(
            'pretrain_%s.csv' % label,
            relative_to_snapshot_dir=True,
        )
        logger.add_tabular_output(
            'progress.csv',
            relative_to_snapshot_dir=True,
        )

        if self.post_bc_pretrain_hyperparams:
            self.set_algorithm_weights(**self.post_bc_pretrain_hyperparams)
Example #24
def offpolicy_inference():
    import time
    from gym import wrappers

    filename = str(uuid.uuid4())

    gpu = True

    env, _, _ = prepare_env(args.env_name, args.visionmodel_path, **env_kwargs)

    snapshot = torch.load(args.load_name)
    policy = snapshot['evaluation/policy']
    if args.env_name.find('doorenv') > -1:
        policy.knob_noisy = args.knob_noisy
        policy.nn = env._wrapped_env.nn
        policy.visionnet_input = env_kwargs['visionnet_input']

    epi_counter = 1
    dooropen_counter = 0
    total_time = 0
    test_num = 100

    # render only for non-evaluation runs outside Unity
    render = not evaluation and not args.unity

    start_time = int(time.mktime(time.localtime()))

    if gpu:
        set_gpu_mode(True)
    while True:
        if args.env_name.find('doorenv') > -1:
            path, door_opened, opening_time = rollout(
                env,
                policy,
                max_path_length=512,
                doorenv=True,
                render=render,
                evaluate=True,
            )
            print("done first")
            if hasattr(env, "log_diagnostics"):
                env.log_diagnostics([path])
            logger.dump_tabular()
            if evaluation:
                env, _, _ = prepare_env(args.env_name, args.visionmodel_path,
                                        **env_kwargs)
                if door_opened:
                    dooropen_counter += 1
                    total_time += opening_time
                    eval_print(dooropen_counter, epi_counter, start_time,
                               total_time)

        else:
            path = rollout(
                env,
                policy,
                max_path_length=512,
                doorenv=False,
                render=render,
            )
            if hasattr(env, "log_diagnostics"):
                env.log_diagnostics([path])
            logger.dump_tabular()

        if evaluation:
            print("{} ep end >>>>>>>>>>>>>>>>>>>>>>>>".format(epi_counter))
            epi_counter += 1

            if args.env_name.find('door') > -1 and epi_counter > test_num:
                eval_print(dooropen_counter, epi_counter, start_time,
                           total_time)
                break
Example #25
    else:
        max_tau = args.mtau

    env = data['env']
    policy = data['policy']
    policy.train(False)

    if args.gpu:
        ptu.set_gpu_mode(True)
        policy.cuda()

    while True:
        paths = []
        for _ in range(args.nrolls):
            goal = env.sample_goal_for_rollout()
            path = multitask_rollout(
                env,
                policy,
                init_tau=max_tau,
                goal=goal,
                max_path_length=args.H,
                animated=not args.hide,
                cycle_tau=True,
                decrement_tau=True,
            )
            paths.append(path)
        env.log_diagnostics(paths)
        for key, value in get_generic_path_information(paths).items():
            logger.record_tabular(key, value)
        logger.dump_tabular()
Example #26
def train_pixelcnn(
    vqvae=None,
    vqvae_path=None,
    num_epochs=100,
    batch_size=32,
    n_layers=15,
    dataset_path=None,
    save=True,
    save_period=10,
    cached_dataset_path=False,
    trainer_kwargs=None,
    model_kwargs=None,
    data_filter_fn=lambda x: x,
    debug=False,
    data_size=float('inf'),
    num_train_batches_per_epoch=None,
    num_test_batches_per_epoch=None,
    train_img_loader=None,
    test_img_loader=None,
):
    trainer_kwargs = {} if trainer_kwargs is None else trainer_kwargs
    model_kwargs = {} if model_kwargs is None else model_kwargs

    # Load VQVAE + Define Args
    if vqvae is None:
        vqvae = load_local_or_remote_file(vqvae_path)
        vqvae.to(ptu.device)
        vqvae.eval()

    root_len = vqvae.root_len
    num_embeddings = vqvae.num_embeddings
    embedding_dim = vqvae.embedding_dim
    cond_size = vqvae.num_embeddings
    imsize = vqvae.imsize
    discrete_size = root_len * root_len
    representation_size = embedding_dim * discrete_size
    input_channels = vqvae.input_channels
    imlength = imsize * imsize * input_channels

    log_dir = logger.get_snapshot_dir()

    # Define data loading info
    new_path = osp.join(log_dir, 'pixelcnn_data.npy')

    def prep_sample_data(cached_path):
        data = load_local_or_remote_file(cached_path).item()
        train_data = data['train']
        test_data = data['test']
        return train_data, test_data

    def encode_dataset(path, object_list):
        data = load_local_or_remote_file(path)
        data = data.item()
        data = data_filter_fn(data)

        all_data = []
        n = min(data["observations"].shape[0], data_size)

        for i in tqdm(range(n)):
            obs = ptu.from_numpy(data["observations"][i] / 255.0)
            latent = vqvae.encode(obs, cont=False)
            all_data.append(latent)

        encodings = ptu.get_numpy(torch.stack(all_data, dim=0))
        return encodings

    if train_img_loader:
        _, test_loader, test_batch_loader = create_conditional_data_loader(
            test_img_loader, 80, vqvae, "test2")  # 80
        _, train_loader, train_batch_loader = create_conditional_data_loader(
            train_img_loader, 2000, vqvae, "train2")  # 2000
    else:
        if cached_dataset_path:
            train_data, test_data = prep_sample_data(cached_dataset_path)
        else:
            train_data = encode_dataset(dataset_path['train'],
                                        None)  # object_list)
            test_data = encode_dataset(dataset_path['test'], None)
        dataset = {'train': train_data, 'test': test_data}
        np.save(new_path, dataset)

        _, _, train_loader, test_loader, _ = \
            rlkit.torch.vae.pixelcnn_utils.load_data_and_data_loaders(new_path, 'COND_LATENT_BLOCK', batch_size)

    #train_dataset = InfiniteBatchLoader(train_loader)
    #test_dataset = InfiniteBatchLoader(test_loader)

    print("Finished loading data")

    model = GatedPixelCNN(num_embeddings,
                          root_len**2,
                          n_classes=representation_size,
                          **model_kwargs).to(ptu.device)
    trainer = PixelCNNTrainer(
        model,
        vqvae,
        batch_size=batch_size,
        **trainer_kwargs,
    )

    print("Starting training")

    BEST_LOSS = 999
    for epoch in range(num_epochs):
        should_save = (epoch % save_period == 0) and (epoch > 0)
        trainer.train_epoch(epoch, train_loader, num_train_batches_per_epoch)
        trainer.test_epoch(epoch, test_loader, num_test_batches_per_epoch)

        # `bz` was undefined in the original; batch_size appears to be the
        # intent (note these batch loaders only exist on the img-loader path)
        test_data = test_batch_loader.random_batch(batch_size)["x"]
        train_data = train_batch_loader.random_batch(batch_size)["x"]
        trainer.dump_samples(epoch, test_data, test=True)
        trainer.dump_samples(epoch, train_data, test=False)

        if should_save:
            logger.save_itr_params(epoch, model)

        stats = trainer.get_diagnostics()

        cur_loss = stats["test/loss"]
        if cur_loss < BEST_LOSS:
            BEST_LOSS = cur_loss
            vqvae.set_pixel_cnn(model)
            logger.save_extra_data(vqvae, 'best_vqvae', mode='torch')
        else:
            return vqvae

        for k, v in stats.items():
            logger.record_tabular(k, v)
        logger.dump_tabular()
        trainer.end_epoch(epoch)

    return vqvae
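A hypothetical invocation using the image-loader code path (my_train_loader and my_test_loader are placeholders; the per-epoch sample dumping above relies on the batch loaders created in that branch):

vqvae = train_pixelcnn(
    vqvae_path='path/to/vqvae.pt',      # placeholder path
    train_img_loader=my_train_loader,   # placeholder loaders
    test_img_loader=my_test_loader,
    num_epochs=100,
    batch_size=32,
    save_period=10,
)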
Example #27
    def test_epoch(
        self,
        epoch,
        sample_batch=None,
        key=None,
        save_reconstruction=True,
        save_vae=True,
        from_rl=False,
        save_prefix='r',
        only_train_vae=False,
    ):
        self.model.eval()
        losses = []
        log_probs = []
        triplet_losses = []
        matching_losses = []
        vae_matching_losses = []
        kles = []
        lstm_kles = []
        ae_losses = []
        contrastive_losses = []
        beta = float(self.beta_schedule.get_value(epoch))
        for batch_idx in range(10):
            # print(batch_idx)
            if sample_batch is not None:
                data = sample_batch(self.batch_size, key=key)
                next_obs = data['next_obs']
            else:
                next_obs = self.get_batch(epoch=epoch)

            reconstructions, obs_distribution_params, vae_latent_distribution_params, lstm_latent_encodings = self.model(
                next_obs)
            latent_encodings = lstm_latent_encodings
            vae_mu = vae_latent_distribution_params[0]  # this is lstm inputs
            latent_distribution_params = vae_latent_distribution_params

            triplet_loss = ptu.zeros(1)
            for tri_idx, triplet_type in enumerate(self.triplet_loss_type):
                if triplet_type == 1 and not only_train_vae:
                    triplet_loss += self.triplet_loss_coef[
                        tri_idx] * self.triplet_loss(latent_encodings)
                elif triplet_type == 2 and not only_train_vae:
                    triplet_loss += self.triplet_loss_coef[
                        tri_idx] * self.triplet_loss_2(next_obs)
                elif triplet_type == 3 and not only_train_vae:
                    triplet_loss += self.triplet_loss_coef[
                        tri_idx] * self.triplet_loss_3(next_obs)

            if self.matching_loss_coef > 0 and not only_train_vae:
                matching_loss = self.matching_loss(next_obs)
            else:
                matching_loss = ptu.zeros(1)

            if self.vae_matching_loss_coef > 0:
                matching_loss_vae = self.matching_loss_vae(next_obs)
            else:
                matching_loss_vae = ptu.zeros(1)

            if self.contrastive_loss_coef > 0 and not only_train_vae:
                contrastive_loss = self.contrastive_loss(next_obs)
            else:
                contrastive_loss = ptu.zeros(1)

            log_prob = self.model.logprob(next_obs, obs_distribution_params)
            kle = self.model.kl_divergence(latent_distribution_params)
            lstm_kle = ptu.zeros(1)

            ae_loss = F.mse_loss(
                latent_encodings.view((-1, self.model.representation_size)),
                vae_mu.detach())
            ae_losses.append(ae_loss.item())

            loss = (-self.recon_loss_coef * log_prob + beta * kle
                    + self.matching_loss_coef * matching_loss
                    + self.ae_loss_coef * ae_loss
                    + triplet_loss
                    + self.vae_matching_loss_coef * matching_loss_vae
                    + self.contrastive_loss_coef * contrastive_loss)

            losses.append(loss.item())
            log_probs.append(log_prob.item())
            triplet_losses.append(triplet_loss.item())
            matching_losses.append(matching_loss.item())
            vae_matching_losses.append(matching_loss_vae.item())
            kles.append(kle.item())
            lstm_kles.append(lstm_kle.item())
            contrastive_losses.append(contrastive_loss.item())

            if batch_idx == 0 and save_reconstruction:
                seq_len, batch_size, feature_size = next_obs.shape
                show_obs = next_obs[0][:8]
                reconstructions = reconstructions.view(
                    (seq_len, batch_size, feature_size))[0][:8]
                comparison = torch.cat([
                    show_obs.narrow(start=0, length=self.imlength,
                                    dim=1).contiguous().view(
                                        -1, self.input_channels, self.imsize,
                                        self.imsize).transpose(2, 3),
                    reconstructions.view(
                        -1,
                        self.input_channels,
                        self.imsize,
                        self.imsize,
                    ).transpose(2, 3)
                ])
                save_dir = osp.join(logger.get_snapshot_dir(),
                                    '{}{}.png'.format(save_prefix, epoch))
                save_image(comparison.data.cpu(), save_dir, nrow=8)

        self.eval_statistics['epoch'] = epoch
        self.eval_statistics['test/log prob'] = np.mean(log_probs)
        self.eval_statistics['test/triplet loss'] = np.mean(triplet_losses)
        self.eval_statistics['test/vae matching loss'] = np.mean(
            vae_matching_losses)
        self.eval_statistics['test/matching loss'] = np.mean(matching_losses)
        self.eval_statistics['test/KL'] = np.mean(kles)
        self.eval_statistics['test/lstm KL'] = np.mean(lstm_kles)
        self.eval_statistics['test/loss'] = np.mean(losses)
        self.eval_statistics['test/contrastive loss'] = np.mean(
            contrastive_losses)
        self.eval_statistics['beta'] = beta
        self.eval_statistics['test/ae loss'] = np.mean(ae_losses)

        if not from_rl:
            for k, v in self.eval_statistics.items():
                logger.record_tabular(k, v)
            logger.dump_tabular()
            if save_vae:
                logger.save_itr_params(epoch, self.model)

        torch.cuda.empty_cache()
Example #28
def train(dataset_generator,
          n_start_samples,
          projection=project_samples_square_np,
          n_samples_to_add_per_epoch=1000,
          n_epochs=100,
          z_dim=1,
          hidden_size=32,
          save_period=10,
          append_all_data=True,
          full_variant=None,
          dynamics_noise=0,
          decoder_output_var='learned',
          num_bins=5,
          skew_config=None,
          use_perfect_samples=False,
          use_perfect_density=False,
          vae_reset_period=0,
          vae_kwargs=None,
          use_dataset_generator_first_epoch=True,
          **kwargs):
    """
    Sanitize Inputs
    """
    assert skew_config is not None
    if not (use_perfect_density and use_perfect_samples):
        assert vae_kwargs is not None
    if vae_kwargs is None:
        vae_kwargs = {}

    report = HTMLReport(
        logger.get_snapshot_dir() + '/report.html',
        images_per_row=10,
    )
    dynamics = Dynamics(projection, dynamics_noise)
    if full_variant:
        report.add_header("Variant")
        report.add_text(
            json.dumps(
                ppp.dict_to_safe_json(full_variant, sort=True),
                indent=2,
            ))

    vae, decoder, decoder_opt, encoder, encoder_opt = get_vae(
        decoder_output_var,
        hidden_size,
        z_dim,
        vae_kwargs,
    )
    vae.to(ptu.device)

    epochs = []
    losses = []
    kls = []
    log_probs = []
    hist_heatmap_imgs = []
    vae_heatmap_imgs = []
    sample_imgs = []
    entropies = []
    tvs_to_uniform = []
    entropy_gains_from_reweighting = []
    p_theta = Histogram(num_bins)
    p_new = Histogram(num_bins)

    orig_train_data = dataset_generator(n_start_samples)
    train_data = orig_train_data
    start = time.time()
    for epoch in progressbar(range(n_epochs)):
        p_theta = Histogram(num_bins)
        if epoch == 0 and use_dataset_generator_first_epoch:
            vae_samples = dataset_generator(n_samples_to_add_per_epoch)
        else:
            if use_perfect_samples and epoch != 0:
                # Ideally the VAE = p_new, but in practice, it won't be...
                vae_samples = p_new.sample(n_samples_to_add_per_epoch)
            else:
                vae_samples = vae.sample(n_samples_to_add_per_epoch)
        projected_samples = dynamics(vae_samples)
        if append_all_data:
            train_data = np.vstack((train_data, projected_samples))
        else:
            train_data = np.vstack((orig_train_data, projected_samples))

        p_theta.fit(train_data)
        if use_perfect_density:
            prob = p_theta.compute_density(train_data)
        else:
            prob = vae.compute_density(train_data)
        all_weights = prob_to_weight(prob, skew_config)
        p_new.fit(train_data, weights=all_weights)
        if epoch == 0 or (epoch + 1) % save_period == 0:
            epochs.append(epoch)
            report.add_text("Epoch {}".format(epoch))
            hist_heatmap_img = visualize_histogram(p_theta, skew_config,
                                                   report)
            vae_heatmap_img = visualize_vae(
                vae,
                skew_config,
                report,
                resolution=num_bins,
            )
            sample_img = visualize_vae_samples(
                epoch,
                train_data,
                vae,
                report,
                dynamics,
            )

            visualize_samples(
                p_theta.sample(n_samples_to_add_per_epoch),
                report,
                title="P Theta/RB Samples",
            )
            visualize_samples(
                p_new.sample(n_samples_to_add_per_epoch),
                report,
                title="P Adjusted Samples",
            )
            hist_heatmap_imgs.append(hist_heatmap_img)
            vae_heatmap_imgs.append(vae_heatmap_img)
            sample_imgs.append(sample_img)
            report.save()

            Image.fromarray(
                hist_heatmap_img).save(logger.get_snapshot_dir() +
                                       '/hist_heatmap{}.png'.format(epoch))
            Image.fromarray(
                vae_heatmap_img).save(logger.get_snapshot_dir() +
                                      '/vae_heatmap{}.png'.format(epoch))
            Image.fromarray(sample_img).save(logger.get_snapshot_dir() +
                                             '/samples{}.png'.format(epoch))
        """
        train VAE to look like p_new
        """
        if sum(all_weights) == 0:
            all_weights[:] = 1
        if vae_reset_period > 0 and epoch % vae_reset_period == 0:
            vae, decoder, decoder_opt, encoder, encoder_opt = get_vae(
                decoder_output_var,
                hidden_size,
                z_dim,
                vae_kwargs,
            )
            vae.to(ptu.device)
        vae.fit(train_data, weights=all_weights)
        epoch_stats = vae.get_epoch_stats()

        losses.append(np.mean(epoch_stats['losses']))
        kls.append(np.mean(epoch_stats['kls']))
        log_probs.append(np.mean(epoch_stats['log_probs']))
        entropies.append(p_theta.entropy())
        tvs_to_uniform.append(p_theta.tv_to_uniform())
        entropy_gain = p_new.entropy() - p_theta.entropy()
        entropy_gains_from_reweighting.append(entropy_gain)

        for k in sorted(epoch_stats.keys()):
            logger.record_tabular(k, epoch_stats[k])

        logger.record_tabular("Epoch", epoch)
        logger.record_tabular('Entropy ', p_theta.entropy())
        logger.record_tabular('KL from uniform', p_theta.kl_from_uniform())
        logger.record_tabular('TV to uniform', p_theta.tv_to_uniform())
        logger.record_tabular('Entropy gain from reweight', entropy_gain)
        logger.record_tabular('Total Time (s)', time.time() - start)
        logger.dump_tabular()
        logger.save_itr_params(
            epoch, {
                'vae': vae,
                'train_data': train_data,
                'vae_samples': vae_samples,
                'dynamics': dynamics,
            })

    report.add_header("Training Curves")
    plot_curves(
        [
            ("Training Loss", losses),
            ("KL", kls),
            ("Log Probs", log_probs),
            ("Entropy Gain from Reweighting", entropy_gains_from_reweighting),
        ],
        report,
    )
    plot_curves(
        [
            ("Entropy", entropies),
            ("TV to Uniform", tvs_to_uniform),
        ],
        report,
    )
    report.add_text("Max entropy: {}".format(p_theta.max_entropy()))
    report.save()

    for filename, imgs in [
        ("hist_heatmaps", hist_heatmap_imgs),
        ("vae_heatmaps", vae_heatmap_imgs),
        ("samples", sample_imgs),
    ]:
        video = np.stack(imgs)
        vwrite(
            logger.get_snapshot_dir() + '/{}.mp4'.format(filename),
            video,
        )
        local_gif_file_path = '{}.gif'.format(filename)
        gif_file_path = '{}/{}'.format(logger.get_snapshot_dir(),
                                       local_gif_file_path)
        gif(gif_file_path, video)
        report.add_image(local_gif_file_path, txt=filename, is_url=True)
    report.save()
Example #29
    def _log_stats(self, epoch):
        logger.log(f"Epoch {epoch} finished", with_timestamp=True)

        """
        Replay Buffer
        """
        logger.record_dict(
            self.replay_buffer.get_diagnostics(), prefix="replay_buffer/"
        )

        """
        Trainer
        """
        logger.record_dict(self.trainer.get_diagnostics(), prefix="trainer/")

        """
        Exploration
        """
        logger.record_dict(
            self.expl_data_collector.get_diagnostics(), prefix="exploration/"
        )

        expl_paths = self.expl_data_collector.get_epoch_paths()
        if len(expl_paths) > 0:
            if hasattr(self.expl_env, "get_diagnostics"):
                logger.record_dict(
                    self.expl_env.get_diagnostics(expl_paths),
                    prefix="exploration/",
                )

            logger.record_dict(
                eval_util.get_generic_path_information(expl_paths),
                prefix="exploration/",
            )

        """
        Evaluation
        """
        logger.record_dict(
            self.eval_data_collector.get_diagnostics(),
            prefix="evaluation/",
        )
        eval_paths = self.eval_data_collector.get_epoch_paths()
        if hasattr(self.eval_env, "get_diagnostics"):
            logger.record_dict(
                self.eval_env.get_diagnostics(eval_paths),
                prefix="evaluation/",
            )

        logger.record_dict(
            eval_util.get_generic_path_information(eval_paths),
            prefix="evaluation/",
        )

        """
        Misc
        """
        gt.stamp("logging")
        timings = _get_epoch_timings()
        timings["time/training and exploration (s)"] = self.total_train_expl_time
        logger.record_dict(timings)

        logger.record_tabular("Epoch", epoch)
        logger.dump_tabular(with_prefix=False, with_timestamp=False)
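For reference, _get_epoch_timings() aggregates the same gtimer data used in Example #1. A sketch of what it computes (based on rlkit's helper, which may differ in detail):

import gtimer as gt

def _get_epoch_timings():
    # sum the most recent per-iteration time of every stamp
    times_itrs = gt.get_times().stamps.itrs
    times = {}
    epoch_time = 0
    for key in sorted(times_itrs):
        time_itr = times_itrs[key][-1]
        epoch_time += time_itr
        times['time/{} (s)'.format(key)] = time_itr
    times['time/epoch (s)'] = epoch_time
    times['time/total (s)'] = gt.get_times().total
    return times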
Example #30
    def _log_stats(self, epoch):
        logger.log("Epoch {} finished".format(epoch), with_timestamp=True)
        """
        Replay Buffer
        """
        logger.record_dict(
            self.replay_buffer.get_diagnostics(),
            global_step=epoch,
            prefix="replay_buffer/",
        )
        """
        Trainer
        """
        logger.record_dict(self.trainer.get_diagnostics(),
                           global_step=epoch,
                           prefix="trainer/")
        """
        Exploration
        """
        logger.record_dict(
            self.expl_data_collector.get_diagnostics(),
            global_step=epoch,
            prefix="exploration/",
        )
        expl_paths = self.expl_data_collector.get_epoch_paths()
        if hasattr(self.expl_env, "get_diagnostics"):
            logger.record_dict(
                self.expl_env.get_diagnostics(expl_paths),
                global_step=epoch,
                prefix="exploration/",
            )
        logger.record_dict(
            eval_util.get_generic_path_information(expl_paths),
            global_step=epoch,
            prefix="exploration/",
        )
        """
        Evaluation
        """
        logger.record_dict(
            self.eval_data_collector.get_diagnostics(),
            global_step=epoch,
            prefix="evaluation/",
        )
        eval_paths = self.eval_data_collector.get_epoch_paths()
        if hasattr(self.eval_env, "get_diagnostics"):
            logger.record_dict(
                self.eval_env.get_diagnostics(eval_paths),
                global_step=epoch,
                prefix="evaluation/",
            )
        logger.record_dict(
            eval_util.get_generic_path_information(eval_paths),
            global_step=epoch,
            prefix="evaluation/",
        )
        """
        Misc
        """
        gt.stamp("logging")
        logger.record_dict(_get_epoch_timings(), global_step=epoch)
        logger.record_tabular("Epoch", epoch)
        logger.dump_tabular(with_prefix=False, with_timestamp=False)