    def log_diagnostics(self, paths, **kwargs):
        list_of_rewards, terminals, obs, actions, next_obs = split_paths(paths)

        returns = []
        for rewards in list_of_rewards:
            returns.append(np.sum(rewards))
        statistics = OrderedDict()
        statistics.update(
            create_stats_ordered_dict(
                'Undiscounted Returns',
                returns,
            ))
        statistics.update(
            create_stats_ordered_dict(
                'Rewards',
                list_of_rewards,
            ))
        statistics.update(create_stats_ordered_dict(
            'Actions',
            actions,
        ))

        fraction_of_time_on_platform = [o[1] for o in obs]
        statistics['Fraction of time on platform'] = np.mean(
            fraction_of_time_on_platform)

        for key, value in statistics.items():
            logger.record_tabular(key, value)
        return returns
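For reference, the `create_stats_ordered_dict` helper used throughout these examples summarizes an array (or list of arrays) into scalar statistics keyed by name. A rough, self-contained sketch of the behavior these snippets rely on (the exact key names in railrl may differ):

import numpy as np
from collections import OrderedDict


def create_stats_ordered_dict_sketch(name, data):
    # Approximation of railrl's create_stats_ordered_dict: flatten `data`
    # and emit Mean/Std/Max/Min entries under the given name.
    if isinstance(data, list):
        data = np.concatenate([np.ravel(x) for x in data])
    else:
        data = np.ravel(data)
    return OrderedDict([
        (name + ' Mean', np.mean(data)),
        (name + ' Std', np.std(data)),
        (name + ' Max', np.max(data)),
        (name + ' Min', np.min(data)),
    ])


# Usage mirroring the example above:
# create_stats_ordered_dict_sketch('Undiscounted Returns', [1.0, 2.0, 3.0])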
Example #2
    def evaluate(self, epoch, eval_paths=None):
        statistics = OrderedDict()
        statistics.update(self.eval_statistics)

        logger.log("Collecting samples for evaluation")
        if eval_paths:
            test_paths = eval_paths
        else:
            test_paths = self.get_eval_paths()
        statistics.update(
            eval_util.get_generic_path_information(
                test_paths,
                stat_prefix="Test",
            ))
        # if len(self._exploration_paths) > 0:
        #     statistics.update(eval_util.get_generic_path_information(
        #         self._exploration_paths, stat_prefix="Exploration",
        #     ))
        if hasattr(self.env, "log_diagnostics"):
            self.env.log_diagnostics(test_paths, logger=logger)
        if hasattr(self.env, "get_diagnostics"):
            statistics.update(self.env.get_diagnostics(test_paths))

        average_returns = eval_util.get_average_returns(test_paths)
        statistics['AverageReturn'] = average_returns
        for key, value in statistics.items():
            logger.record_tabular(key, value)
        self.need_to_update_eval_statistics = True
Example #3
    def log_diagnostics(self, paths, **kwargs):
        list_of_rewards, terminals, obs, actions, next_obs = split_paths(paths)

        returns = []
        for rewards in list_of_rewards:
            returns.append(np.sum(rewards))
        last_statistics = OrderedDict()
        last_statistics.update(
            create_stats_ordered_dict(
                'UndiscountedReturns',
                returns,
            ))
        last_statistics.update(
            create_stats_ordered_dict(
                'Rewards',
                list_of_rewards,
            ))
        last_statistics.update(create_stats_ordered_dict(
            'Actions',
            actions,
        ))

        for key, value in last_statistics.items():
            logger.record_tabular(key, value)
        return returns
Example #4
    def log_diagnostics(self, paths):
        final_values = []
        final_unclipped_rewards = []
        final_rewards = []
        for path in paths:
            final_value = path["actions"][-1][0]
            final_values.append(final_value)
            score = path["observations"][0][0] * final_value
            final_unclipped_rewards.append(score)
            final_rewards.append(clip_magnitude(score, 1))

        last_statistics = OrderedDict()
        last_statistics.update(
            create_stats_ordered_dict(
                'Final Value',
                final_values,
            ))
        last_statistics.update(
            create_stats_ordered_dict(
                'Unclipped Final Rewards',
                final_unclipped_rewards,
            ))
        last_statistics.update(
            create_stats_ordered_dict(
                'Final Rewards',
                final_rewards,
            ))

        for key, value in last_statistics.items():
            logger.record_tabular(key, value)

        return final_unclipped_rewards
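`clip_magnitude` is not defined in this listing; from its call site above it plausibly clamps a score to the range [-magnitude, magnitude]. A hypothetical one-liner with that contract:

import numpy as np


def clip_magnitude(value, magnitude):
    # Hypothetical helper: clamp `value` into [-magnitude, magnitude].
    return np.clip(value, -magnitude, magnitude)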
Example #5
    def evaluate(self, epoch):
        """
        Perform evaluation for this algorithm.

        :param epoch: The epoch number.
        """
        statistics = OrderedDict()

        train_batch = self.get_batch()
        statistics.update(self._statistics_from_batch(train_batch, "Train"))

        logger.log("Collecting samples for evaluation")
        test_paths = self._sample_eval_paths()
        statistics.update(
            get_generic_path_information(
                test_paths,
                stat_prefix="Test",
            ))
        statistics.update(self._statistics_from_paths(test_paths, "Test"))
        average_returns = get_average_returns(test_paths)
        statistics['AverageReturn'] = average_returns

        statistics['Epoch'] = epoch

        for key, value in statistics.items():
            logger.record_tabular(key, value)

        self.env.log_diagnostics(test_paths)
        logger.dump_tabular(with_prefix=False, with_timestamp=False)
Example #6
def train_vae(variant, return_data=False):
    from railrl.misc.ml_util import PiecewiseLinearSchedule
    from railrl.torch.vae.vae_trainer import ConvVAETrainer
    from railrl.core import logger
    beta = variant["beta"]
    use_linear_dynamics = variant.get('use_linear_dynamics', False)
    generate_vae_dataset_fctn = variant.get('generate_vae_data_fctn',
                                            generate_vae_dataset)
    variant['generate_vae_dataset_kwargs'][
        'use_linear_dynamics'] = use_linear_dynamics
    train_dataset, test_dataset, info = generate_vae_dataset_fctn(
        variant['generate_vae_dataset_kwargs'])
    if use_linear_dynamics:
        action_dim = train_dataset.data['actions'].shape[2]
    else:
        action_dim = 0
    model = get_vae(variant, action_dim)

    logger.save_extra_data(info)
    logger.get_snapshot_dir()
    if 'beta_schedule_kwargs' in variant:
        beta_schedule = PiecewiseLinearSchedule(
            **variant['beta_schedule_kwargs'])
    else:
        beta_schedule = None

    vae_trainer_class = variant.get('vae_trainer_class', ConvVAETrainer)
    trainer = vae_trainer_class(model,
                                beta=beta,
                                beta_schedule=beta_schedule,
                                **variant['algo_kwargs'])
    save_period = variant['save_period']

    dump_skew_debug_plots = variant.get('dump_skew_debug_plots', False)
    for epoch in range(variant['num_epochs']):
        should_save_imgs = (epoch % save_period == 0)
        trainer.train_epoch(epoch, train_dataset)
        trainer.test_epoch(epoch, test_dataset)

        if should_save_imgs:
            trainer.dump_reconstructions(epoch)
            trainer.dump_samples(epoch)
            if dump_skew_debug_plots:
                trainer.dump_best_reconstruction(epoch)
                trainer.dump_worst_reconstruction(epoch)
                trainer.dump_sampling_histogram(epoch)

        stats = trainer.get_diagnostics()
        for k, v in stats.items():
            logger.record_tabular(k, v)
        logger.dump_tabular()
        trainer.end_epoch(epoch)

        if epoch % 50 == 0:
            logger.save_itr_params(epoch, model)
    logger.save_extra_data(model, 'vae.pkl', mode='pickle')
    if return_data:
        return model, train_dataset, test_dataset
    return model
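A minimal `variant` for `train_vae`, listing only the keys the function above actually reads (`get_vae` may require more); the values here are placeholders, not recommended settings:

variant = dict(
    beta=0.5,
    num_epochs=100,
    save_period=10,
    algo_kwargs=dict(),                  # forwarded to the VAE trainer
    generate_vae_dataset_kwargs=dict(),  # forwarded to the dataset function
    # Optional keys and the defaults train_vae falls back to:
    #   use_linear_dynamics=False,
    #   generate_vae_data_fctn=generate_vae_dataset,
    #   vae_trainer_class=ConvVAETrainer,
    #   dump_skew_debug_plots=False,
    #   beta_schedule_kwargs=...  (enables a PiecewiseLinearSchedule)
)
model = train_vae(variant)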
Example #7
def run_task(variant):
    from railrl.core import logger
    print(variant)
    logger.log("Hello from script")
    logger.log("variant: " + str(variant))
    logger.record_tabular("value", 1)
    logger.dump_tabular()
    logger.log("snapshot_dir:", logger.get_snapshot_dir())
Example #8
def simulate_policy(args):
    data = pickle.load(open(args.file, "rb"))
    policy_key = args.policy_type + '/policy'
    if policy_key in data:
        policy = data[policy_key]
    else:
        raise Exception("No policy found in loaded dict. Keys: {}".format(
            data.keys()))

    env_key = args.env_type + '/env'
    if env_key in data:
        env = data[env_key]
    else:
        raise Exception("No environment found in loaded dict. Keys: {}".format(
            data.keys()))

    # robosuite env-specific setup
    env._wrapped_env.has_renderer = True
    env.reset()
    env.viewer.set_camera(camera_id=0)

    if isinstance(env, RemoteRolloutEnv):
        env = env._wrapped_env
    print("Policy loaded")

    if args.enable_render:
        # some environments need to be reconfigured for visualization
        env.enable_render()
    if args.gpu:
        ptu.set_gpu_mode(True)
    if hasattr(policy, "to"):
        policy.to(ptu.device)
    if hasattr(env, "vae"):
        env.vae.to(ptu.device)

    if args.pause:
        import ipdb
        ipdb.set_trace()
    if isinstance(policy, PyTorchModule):
        policy.train(False)
    paths = []
    while True:
        paths.append(
            deprecated_rollout(
                env,
                policy,
                max_path_length=args.H,
                render=not args.hide,
            ))
        if args.log_diagnostics:
            if hasattr(env, "log_diagnostics"):
                env.log_diagnostics(paths, logger)
            for k, v in eval_util.get_generic_path_information(paths).items():
                logger.record_tabular(k, v)
            logger.dump_tabular()
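`simulate_policy` never builds its own `args`; a hypothetical argparse setup covering every attribute the function reads (the defaults here are guesses, not railrl's):

import argparse

parser = argparse.ArgumentParser()
parser.add_argument('file', help='path to the pickled snapshot')
parser.add_argument('--policy_type', default='evaluation')
parser.add_argument('--env_type', default='evaluation')
parser.add_argument('--H', type=int, default=100, help='max path length')
parser.add_argument('--gpu', action='store_true')
parser.add_argument('--pause', action='store_true')
parser.add_argument('--hide', action='store_true')
parser.add_argument('--enable_render', action='store_true')
parser.add_argument('--log_diagnostics', action='store_true')
simulate_policy(parser.parse_args())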
Example #9
    def log_diagnostics(self, paths):
        n_goal = len(self.goal_positions)
        goal_reached = [False] * n_goal

        for path in paths:
            last_obs = path["observations"][-1]
            for i, goal in enumerate(self.goal_positions):
                if np.linalg.norm(last_obs - goal) < self.goal_threshold:
                    goal_reached[i] = True

        logger.record_tabular('env:goals_reached', goal_reached.count(True))
Example #10
    def log_diagnostics(self, paths):
        statistics = OrderedDict()

        for stat_name in [
                'arm to object distance',
                'object to goal distance',
                'arm to goal distance',
        ]:
            stat = get_stat_in_paths(paths, 'env_infos', stat_name)
            statistics.update(create_stats_ordered_dict(stat_name, stat))

        for key, value in statistics.items():
            logger.record_tabular(key, value)
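`get_stat_in_paths` gathers one `env_infos` entry across every path and timestep. A rough stand-in, assuming `env_infos` is stored as a list of per-step dicts (railrl also supports other layouts):

import numpy as np


def get_stat_in_paths_sketch(paths, dict_name, stat_name):
    # Collect path[dict_name][t][stat_name] for each path; one array per path.
    return [
        np.array([info[stat_name] for info in path[dict_name]])
        for path in paths
    ]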
Example #11
def plot_performance(policy, env, nrolls):
    print("max_tau, distance")
    # fixed_goals = [-40, -30, 30, 40]
    fixed_goals = [-5, -3, 3, 5]
    taus = np.arange(10) * 10
    for row, fix_tau in enumerate([True, False]):
        for col, horizon_fixed in enumerate([True, False]):
            plot_num = row + 2 * col + 1
            plt.subplot(2, 2, plot_num)
            for fixed_goal in fixed_goals:
                distances = []
                for max_tau in taus:
                    paths = []
                    for _ in range(nrolls):
                        goal = env.sample_goal_for_rollout()
                        goal[0] = fixed_goal
                        path = multitask_rollout(
                            env,
                            policy,
                            goal,
                            init_tau=max_tau,
                            max_path_length=100 if horizon_fixed else max_tau + 1,
                            animated=False,
                            cycle_tau=True,
                            decrement_tau=not fix_tau,
                        )
                        paths.append(path)
                    env.log_diagnostics(paths)
                    for key, value in get_generic_path_information(
                            paths).items():
                        logger.record_tabular(key, value)
                    distance = float(
                        dict(logger._tabular)['Final Distance to goal Mean'])
                    distances.append(distance)

                plt.plot(taus, distances)
                print("line done")
            plt.legend([str(goal) for goal in fixed_goals])
            if fix_tau:
                plt.xlabel("Tau (Horizon-1)")
            else:
                plt.xlabel("Initial tau (=Horizon-1)")
            plt.ylabel("Final distance to goal")
            plt.title("Fix Tau = {}, Horizon Fixed to 100 = {}".format(
                fix_tau,
                horizon_fixed,
            ))
    # Save before show(), otherwise the saved figure can come out empty.
    plt.savefig('results/iclr2018/cheetah-sweep-tau-eval-5-3.jpg')
    plt.show()
Example #12
    def train_epoch(self, epoch, batches=100):
        self.model.train()
        losses = []
        kles = []
        mses = []
        beta = self.beta_schedule.get_value(epoch)
        for batch_idx in range(batches):
            data = self.get_batch()
            obs = data['obs']
            next_obs = data['next_obs']
            actions = data['actions']
            self.optimizer.zero_grad()
            recon_batch, mu, logvar = self.model(next_obs)
            mse = self.logprob(recon_batch, next_obs)
            kle = self.kl_divergence(mu, logvar)
            if self.recon_loss_type == 'mse':
                loss = mse + beta * kle
            elif self.recon_loss_type == 'wse':
                wse = self.logprob(recon_batch, next_obs, unorm_weights=self.recon_weights)
                loss = wse + beta * kle
            else:
                raise ValueError(
                    "Unknown recon_loss_type: {}".format(self.recon_loss_type))
            loss.backward()

            losses.append(loss.data[0])
            mses.append(mse.data[0])
            kles.append(kle.data[0])

            self.optimizer.step()

        logger.record_tabular("train/epoch", epoch)
        logger.record_tabular("train/MSE", np.mean(mses))
        logger.record_tabular("train/KL", np.mean(kles))
        logger.record_tabular("train/loss", np.mean(losses))
Example #13
    def evaluate(self, epoch, eval_paths=None):
        statistics = OrderedDict()
        if isinstance(self.epoch_discount_schedule, StatConditionalSchedule):
            table_dict = logger.get_table_dict()
            # rllab converts things to strings for some reason
            value = float(
                table_dict[self.epoch_discount_schedule.statistic_name])
            self.epoch_discount_schedule.update(value)

        if not isinstance(self.epoch_discount_schedule, ConstantSchedule):
            statistics['Discount Factor'] = self.discount

        for key, value in statistics.items():
            logger.record_tabular(key, value)
        super().evaluate(epoch, eval_paths=eval_paths)
Example #14
def simulate_policy(args):
    if args.pause:
        import ipdb; ipdb.set_trace()
    data = pickle.load(open(args.file, "rb")) # joblib.load(args.file)
    if 'policy' in data:
        policy = data['policy']
    elif 'evaluation/policy' in data:
        policy = data['evaluation/policy']
    else:
        raise Exception("No policy found in loaded dict. Keys: {}".format(
            data.keys()))

    if 'env' in data:
        env = data['env']
    elif 'evaluation/env' in data:
        env = data['evaluation/env']
    else:
        raise Exception("No environment found in loaded dict. Keys: {}".format(
            data.keys()))

    if isinstance(env, RemoteRolloutEnv):
        env = env._wrapped_env
    print("Policy loaded")
    ptu.set_gpu_mode(args.gpu)
    policy.to(ptu.device)
    if isinstance(env, VAEWrappedEnv):
        env.mode(args.mode)
    if args.enable_render or hasattr(env, 'enable_render'):
        # some environments need to be reconfigured for visualization
        env.enable_render()
    if args.multitaskpause:
        env.pause_on_goal = True
    if isinstance(policy, PyTorchModule):
        policy.train(False)
    paths = []
    while True:
        paths.append(multitask_rollout(
            env,
            policy,
            max_path_length=args.H,
            render=not args.hide,
            observation_key=data.get('evaluation/observation_key', 'observation'),
            desired_goal_key=data.get('evaluation/desired_goal_key', 'desired_goal'),
        ))
        if hasattr(env, "log_diagnostics"):
            env.log_diagnostics(paths)
        if hasattr(env, "get_diagnostics"):
            for k, v in env.get_diagnostics(paths).items():
                logger.record_tabular(k, v)
        logger.dump_tabular()
Example #15
    def log_loss_under_uniform(self, model, data, priority_function_kwargs):
        import torch.nn.functional as F
        log_probs_prior = []
        log_probs_biased = []
        log_probs_importance = []
        kles = []
        mses = []
        for i in range(0, data.shape[0], self.batch_size):
            img = normalize_image(data[i:min(data.shape[0], i +
                                             self.batch_size), :])
            torch_img = ptu.from_numpy(img)
            reconstructions, obs_distribution_params, latent_distribution_params = self.model(
                torch_img)

            priority_function_kwargs['sampling_method'] = 'true_prior_sampling'
            log_p, log_q, log_d = compute_log_p_log_q_log_d(
                model, img, **priority_function_kwargs)
            log_prob_prior = log_d.mean()

            priority_function_kwargs['sampling_method'] = 'biased_sampling'
            log_p, log_q, log_d = compute_log_p_log_q_log_d(
                model, img, **priority_function_kwargs)
            log_prob_biased = log_d.mean()

            priority_function_kwargs['sampling_method'] = 'importance_sampling'
            log_p, log_q, log_d = compute_log_p_log_q_log_d(
                model, img, **priority_function_kwargs)
            log_prob_importance = (log_p - log_q + log_d).mean()

            kle = model.kl_divergence(latent_distribution_params)
            mse = F.mse_loss(torch_img,
                             reconstructions,
                             reduction='elementwise_mean')
            mses.append(mse.item())
            kles.append(kle.item())
            log_probs_prior.append(log_prob_prior.item())
            log_probs_biased.append(log_prob_biased.item())
            log_probs_importance.append(log_prob_importance.item())

        logger.record_tabular("Uniform Data Log Prob (True Prior)",
                              np.mean(log_probs_prior))
        logger.record_tabular("Uniform Data Log Prob (Biased)",
                              np.mean(log_probs_biased))
        logger.record_tabular("Uniform Data Log Prob (Importance)",
                              np.mean(log_probs_importance))
        logger.record_tabular("Uniform Data KL", np.mean(kles))
        logger.record_tabular("Uniform Data MSE", np.mean(mses))
Example #16
    def log_diagnostics(self, paths):
        Ntraj = len(paths)
        acts = np.array([traj['actions'] for traj in paths])
        obs = np.array([traj['observations'] for traj in paths])

        state_count = np.sum(obs, axis=1)
        states_visited = np.sum(state_count > 0, axis=-1)
        # log states visited
        logger.record_tabular('AvgStatesVisited', np.mean(states_visited))

        # log action block lengths
        traj_idx, _, acts_idx = np.where(acts == 1)
        # Ragged per-trajectory index lists; a plain list avoids numpy's
        # ragged-array restriction.
        acts_idx = [acts_idx[traj_idx == i] for i in range(Ntraj)]

        if self.zero_reward:
            task_reward = np.array(
                [traj['env_infos']['task_reward'] for traj in paths])
            logger.record_tabular('ZeroedTaskReward',
                                  np.mean(np.sum(task_reward, axis=1)))
Example #17
    def test_epoch(
        self,
        epoch,
    ):
        self.model.eval()
        val_losses = []
        per_dim_losses = np.zeros((self.num_batches, self.y_train.shape[1]))
        for batch in range(self.num_batches):
            inputs_np, labels_np = self.random_batch(
                self.X_test, self.y_test, batch_size=self.batch_size)
            inputs = ptu.Variable(ptu.from_numpy(inputs_np))
            labels = ptu.Variable(ptu.from_numpy(labels_np))
            outputs = self.model(inputs)
            loss = self.criterion(outputs, labels)
            val_losses.append(loss.data[0])
            per_dim_loss = np.mean(ptu.get_numpy(outputs - labels) ** 2,
                                   axis=0)
            per_dim_losses[batch] = per_dim_loss

        logger.record_tabular("test/epoch", epoch)
        logger.record_tabular("test/loss", np.mean(np.array(val_losses)))
        for i in range(self.y_train.shape[1]):
            logger.record_tabular("test/dim " + str(i) + " loss",
                                  np.mean(per_dim_losses[:, i]))
        logger.dump_tabular()
Example #18
    def evaluate(self, epoch, eval_paths=None):
        statistics = OrderedDict()
        statistics.update(self.eval_statistics)

        logger.log("Collecting samples for evaluation")
        if eval_paths:
            test_paths = eval_paths
        else:
            test_paths = self.get_eval_paths()
        statistics.update(
            eval_util.get_generic_path_information(
                test_paths,
                stat_prefix="Test",
            ))
        if len(self._exploration_paths) > 0:
            statistics.update(
                eval_util.get_generic_path_information(
                    self._exploration_paths,
                    stat_prefix="Exploration",
                ))
        if hasattr(self.env, "log_diagnostics"):
            self.env.log_diagnostics(test_paths, logger=logger)
        if hasattr(self.env, "get_diagnostics"):
            statistics.update(self.env.get_diagnostics(test_paths))

        if hasattr(self.eval_policy, "log_diagnostics"):
            self.eval_policy.log_diagnostics(test_paths, logger=logger)
        if hasattr(self.eval_policy, "get_diagnostics"):
            statistics.update(self.eval_policy.get_diagnostics(test_paths))

        process = psutil.Process(os.getpid())
        statistics['RAM Usage (Mb)'] = int(process.memory_info().rss / 1000000)

        statistics['Exploration Policy Noise'] = self._exploration_policy_noise

        average_returns = eval_util.get_average_returns(test_paths)
        statistics['AverageReturn'] = average_returns
        for key, value in statistics.items():
            logger.record_tabular(key, value)
        self.need_to_update_eval_statistics = True
Example #19
    def test_epoch(self, epoch, save_network=True, batches=100):
        self.model.eval()
        mses = []
        losses = []
        for batch_idx in range(batches):
            data = self.get_batch(train=False)
            z = data["z"]
            z_proj = data['z_proj']
            z_proj_hat = self.model(z)
            mse = self.mse_loss(z_proj_hat, z_proj)
            loss = mse

            mses.append(mse.data[0])
            losses.append(loss.data[0])

        logger.record_tabular("test/epoch", epoch)
        logger.record_tabular("test/MSE", np.mean(mses))
        logger.record_tabular("test/loss", np.mean(losses))

        logger.dump_tabular()
        if save_network:
            logger.save_itr_params(epoch,
                                   self.model,
                                   prefix='reproj',
                                   save_anyway=True)
Example #20
    def train_epoch(self, epoch):
        self.model.train()
        losses = []
        per_dim_losses = np.zeros((self.num_batches, self.y_train.shape[1]))
        for batch in range(self.num_batches):
            inputs_np, labels_np = self.random_batch(
                self.X_train, self.y_train, batch_size=self.batch_size)
            inputs = ptu.Variable(ptu.from_numpy(inputs_np))
            labels = ptu.Variable(ptu.from_numpy(labels_np))
            self.optimizer.zero_grad()
            outputs = self.model(inputs)
            loss = self.criterion(outputs, labels)
            loss.backward()
            self.optimizer.step()
            losses.append(loss.data[0])
            per_dim_loss = np.mean(ptu.get_numpy(outputs - labels) ** 2,
                                   axis=0)
            per_dim_losses[batch] = per_dim_loss

        logger.record_tabular("train/epoch", epoch)
        logger.record_tabular("train/loss", np.mean(np.array(losses)))
        for i in range(self.y_train.shape[1]):
            logger.record_tabular("train/dim " + str(i) + " loss",
                                  np.mean(per_dim_losses[:, i]))
Example #21
    def log_diagnostics(self, paths):
        target_onehots = []
        for path in paths:
            first_observation = path["observations"][0][:self.n + 1]
            target_onehots.append(first_observation)

        final_predictions = []  # each element has shape (dim)
        nonfinal_predictions = []  # each element has shape (seq_length-1, dim)
        for path in paths:
            actions = path["actions"]
            if self._softmax_action:
                actions = softmax(actions, axis=-1)
            final_predictions.append(actions[-1])
            nonfinal_predictions.append(actions[:-1])
        nonfinal_predictions_sequence_dimension_flattened = np.vstack(
            nonfinal_predictions)  # shape = N X dim
        nonfinal_prob_zero = [
            probs[0]
            for probs in nonfinal_predictions_sequence_dimension_flattened
        ]
        final_probs_correct = []
        for final_prediction, target_onehot in zip(final_predictions,
                                                   target_onehots):
            correct_pred_idx = np.argmax(target_onehot)
            final_probs_correct.append(final_prediction[correct_pred_idx])
        final_prob_zero = [probs[0] for probs in final_predictions]

        last_statistics = OrderedDict()
        last_statistics.update(
            create_stats_ordered_dict('Final P(correct)', final_probs_correct))
        last_statistics.update(
            create_stats_ordered_dict('Non-final P(zero)', nonfinal_prob_zero))
        last_statistics.update(
            create_stats_ordered_dict('Final P(zero)', final_prob_zero))

        for key, value in last_statistics.items():
            logger.record_tabular(key, value)

        return final_probs_correct
Example #22
    def evaluate(self, epoch, exploration_paths):
        """
        Perform evaluation for this algorithm.

        :param epoch: The epoch number.
        :param exploration_paths: List of dicts, each representing a path.
        """
        logger.log("Collecting samples for evaluation")
        paths = self._sample_eval_paths(epoch)
        statistics = OrderedDict()

        statistics.update(self._statistics_from_paths(paths, "Test"))
        statistics.update(self._get_other_statistics())
        statistics.update(self._statistics_from_paths(exploration_paths,
                                                      "Exploration"))

        statistics['AverageReturn'] = get_average_returns(paths)
        statistics['Epoch'] = epoch

        for key, value in statistics.items():
            logger.record_tabular(key, value)

        self.log_diagnostics(paths)
Example #23
    def evaluate(self, epoch):
        """
        Perform evaluation for this algorithm.

        :param epoch: The epoch number.
        """
        statistics = OrderedDict()
        train_batch = self.get_batch(training=True)
        statistics.update(self._statistics_from_batch(train_batch, "Train"))
        validation_batch = self.get_batch(training=False)
        statistics.update(
            self._statistics_from_batch(validation_batch, "Validation")
        )

        statistics['QF Loss Validation - Train Gap'] = (
            statistics['Validation QF Loss Mean']
            - statistics['Train QF Loss Mean']
        )
        statistics['Epoch'] = epoch
        for key, value in statistics.items():
            logger.record_tabular(key, value)
        logger.dump_tabular(with_prefix=False, with_timestamp=False)
Example #24
    def train_epoch(self, epoch, batches=100):
        self.model.train()
        mses = []
        losses = []
        for batch_idx in range(batches):
            data = self.get_batch()
            z = data["z"]
            z_proj = data['z_proj']
            self.optimizer.zero_grad()
            z_proj_hat = self.model(z)
            mse = self.mse_loss(z_proj_hat, z_proj)
            loss = mse
            loss.backward()

            mses.append(mse.data[0])
            losses.append(loss.data[0])

            self.optimizer.step()

        logger.record_tabular("train/epoch", epoch)
        logger.record_tabular("train/MSE", np.mean(mses))
        logger.record_tabular("train/loss", np.mean(losses))
Example #25
def get_n_train_vae(latent_dim,
                    env,
                    vae_train_epochs,
                    num_image_examples,
                    vae_kwargs,
                    vae_trainer_kwargs,
                    vae_architecture,
                    vae_save_period=10,
                    vae_test_p=.9,
                    decoder_activation='sigmoid',
                    vae_class='VAE',
                    **kwargs):
    env.goal_sampling_mode = 'test'
    image_examples = unnormalize_image(
        env.sample_goals(num_image_examples)['desired_goal'])
    n = int(num_image_examples * vae_test_p)
    train_dataset = ImageObservationDataset(image_examples[:n, :])
    test_dataset = ImageObservationDataset(image_examples[n:, :])

    if decoder_activation == 'sigmoid':
        decoder_activation = torch.nn.Sigmoid()

    vae_class = vae_class.lower()
    if vae_class == 'VAE'.lower():
        vae_class = ConvVAE
    elif vae_class == 'SpatialVAE'.lower():
        vae_class = SpatialAutoEncoder
    else:
        raise RuntimeError("Invalid VAE Class: {}".format(vae_class))

    vae = vae_class(latent_dim,
                    architecture=vae_architecture,
                    decoder_output_activation=decoder_activation,
                    **vae_kwargs)

    trainer = ConvVAETrainer(vae, **vae_trainer_kwargs)

    logger.remove_tabular_output('progress.csv', relative_to_snapshot_dir=True)
    logger.add_tabular_output('vae_progress.csv',
                              relative_to_snapshot_dir=True)
    for epoch in range(vae_train_epochs):
        should_save_imgs = (epoch % vae_save_period == 0)
        trainer.train_epoch(epoch, train_dataset)
        trainer.test_epoch(epoch, test_dataset)

        if should_save_imgs:
            trainer.dump_reconstructions(epoch)
            trainer.dump_samples(epoch)
        stats = trainer.get_diagnostics()
        for k, v in stats.items():
            logger.record_tabular(k, v)

        logger.dump_tabular()
        trainer.end_epoch(epoch)

        if epoch % 50 == 0:
            logger.save_itr_params(epoch, vae)
    logger.save_extra_data(vae, 'vae.pkl', mode='pickle')
    logger.remove_tabular_output('vae_progress.csv',
                                 relative_to_snapshot_dir=True)
    logger.add_tabular_output('progress.csv', relative_to_snapshot_dir=True)
    return vae
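The `remove_tabular_output`/`add_tabular_output` calls above temporarily redirect tabular rows into `vae_progress.csv` so that VAE pretraining does not pollute the main `progress.csv`. The pattern in isolation, with a `try/finally` added here as a defensive variant:

from railrl.core import logger

logger.remove_tabular_output('progress.csv', relative_to_snapshot_dir=True)
logger.add_tabular_output('vae_progress.csv', relative_to_snapshot_dir=True)
try:
    pass  # ... record_tabular / dump_tabular for the pretraining phase ...
finally:
    logger.remove_tabular_output('vae_progress.csv',
                                 relative_to_snapshot_dir=True)
    logger.add_tabular_output('progress.csv', relative_to_snapshot_dir=True)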
Example #26
def train_rfeatures_model(variant, return_data=False):
    from railrl.misc.ml_util import PiecewiseLinearSchedule
    # from railrl.torch.vae.conv_vae import (
    #     ConvVAE, ConvResnetVAE
    # )
    import railrl.torch.vae.conv_vae as conv_vae
    # from railrl.torch.vae.vae_trainer import ConvVAETrainer
    from railrl.launchers.experiments.ashvin.rfeatures.rfeatures_model import TimestepPredictionModel
    from railrl.launchers.experiments.ashvin.rfeatures.rfeatures_trainer import TimePredictionTrainer
    from railrl.core import logger
    import railrl.torch.pytorch_util as ptu
    from railrl.pythonplusplus import identity
    import torch
    output_classes = variant["output_classes"]
    representation_size = variant["representation_size"]
    batch_size = variant["batch_size"]
    variant['dataset_kwargs']["output_classes"] = output_classes
    train_dataset, test_dataset, info = get_data(variant['dataset_kwargs'])

    num_train_workers = variant.get("num_train_workers",
                                    0)  # 0 uses main process (good for pdb)
    train_dataset_loader = InfiniteBatchLoader(
        train_dataset,
        batch_size=batch_size,
        num_workers=num_train_workers,
    )
    test_dataset_loader = InfiniteBatchLoader(
        test_dataset,
        batch_size=batch_size,
    )

    logger.save_extra_data(info)
    logger.get_snapshot_dir()
    if variant.get('decoder_activation', None) == 'sigmoid':
        decoder_activation = torch.nn.Sigmoid()
    else:
        decoder_activation = identity
    architecture = variant['model_kwargs'].get('architecture', None)
    if not architecture and variant.get('imsize') == 84:
        architecture = conv_vae.imsize84_default_architecture
    elif not architecture and variant.get('imsize') == 48:
        architecture = conv_vae.imsize48_default_architecture
    variant['model_kwargs']['architecture'] = architecture

    model_class = variant.get('model_class', TimestepPredictionModel)
    model = model_class(
        representation_size,
        decoder_output_activation=decoder_activation,
        output_classes=output_classes,
        **variant['model_kwargs'],
    )
    # model = torch.nn.DataParallel(model)
    model.to(ptu.device)

    variant['trainer_kwargs']['batch_size'] = batch_size
    trainer_class = variant.get('trainer_class', TimePredictionTrainer)
    trainer = trainer_class(
        model,
        **variant['trainer_kwargs'],
    )
    save_period = variant['save_period']

    trainer.dump_trajectory_rewards(
        "initial", dict(train=train_dataset.dataset,
                        test=test_dataset.dataset))

    dump_skew_debug_plots = variant.get('dump_skew_debug_plots', False)
    for epoch in range(variant['num_epochs']):
        should_save_imgs = (epoch % save_period == 0)
        trainer.train_epoch(epoch, train_dataset_loader, batches=10)
        trainer.test_epoch(epoch, test_dataset_loader, batches=1)

        if should_save_imgs:
            trainer.dump_reconstructions(epoch)

        trainer.dump_trajectory_rewards(
            epoch, dict(train=train_dataset.dataset,
                        test=test_dataset.dataset), should_save_imgs)

        stats = trainer.get_diagnostics()
        for k, v in stats.items():
            logger.record_tabular(k, v)
        logger.dump_tabular()
        trainer.end_epoch(epoch)

        if epoch % 50 == 0:
            logger.save_itr_params(epoch, model)
    logger.save_extra_data(model, 'vae.pkl', mode='pickle')
    if return_data:
        return model, train_dataset, test_dataset
    return model
Example #27
    def log_diagnostics(self, paths):
        reward_fwd = get_stat_in_paths(paths, 'env_infos', 'reward_fwd')
        reward_ctrl = get_stat_in_paths(paths, 'env_infos', 'reward_ctrl')

        logger.record_tabular('AvgRewardDist', np.mean(reward_fwd))
        logger.record_tabular('AvgRewardCtrl', np.mean(reward_ctrl))
        if len(paths) > 0:
            progs = [
                path["observations"][-1][-3] - path["observations"][0][-3]
                for path in paths
            ]
            logger.record_tabular('AverageForwardProgress', np.mean(progs))
            logger.record_tabular('MaxForwardProgress', np.max(progs))
            logger.record_tabular('MinForwardProgress', np.min(progs))
            logger.record_tabular('StdForwardProgress', np.std(progs))
        else:
            logger.record_tabular('AverageForwardProgress', np.nan)
            logger.record_tabular('MaxForwardProgress', np.nan)
            logger.record_tabular('MinForwardProgress', np.nan)
            logger.record_tabular('StdForwardProgress', np.nan)
Example #28
    def test_epoch(self, epoch, save_vae=True, **kwargs):
        self.model.eval()
        losses = []
        kles = []
        zs = []

        recon_logging_dict = {
            'MSE': [],
            'WSE': [],
        }
        for k in self.extra_recon_logging:
            recon_logging_dict[k] = []

        beta = self.beta_schedule.get_value(epoch)
        for batch_idx in range(100):
            data = self.get_batch(train=False)
            obs = data['obs']
            next_obs = data['next_obs']
            actions = data['actions']
            recon_batch, mu, logvar = self.model(next_obs)
            mse = self.logprob(recon_batch, next_obs)
            wse = self.logprob(recon_batch, next_obs, unorm_weights=self.recon_weights)
            for k, idx in self.extra_recon_logging.items():
                recon_loss = self.logprob(recon_batch, next_obs, idx=idx)
                recon_logging_dict[k].append(recon_loss.data[0])
            kle = self.kl_divergence(mu, logvar)
            if self.recon_loss_type == 'mse':
                loss = mse + beta * kle
            elif self.recon_loss_type == 'wse':
                loss = wse + beta * kle
            else:
                raise ValueError(
                    "Unknown recon_loss_type: {}".format(self.recon_loss_type))
            z_data = ptu.get_numpy(mu.cpu())
            for i in range(len(z_data)):
                zs.append(z_data[i, :])
            losses.append(loss.data[0])
            recon_logging_dict['WSE'].append(wse.data[0])
            recon_logging_dict['MSE'].append(mse.data[0])
            kles.append(kle.data[0])
        zs = np.array(zs)
        self.model.dist_mu = zs.mean(axis=0)
        self.model.dist_std = zs.std(axis=0)

        for k in recon_logging_dict:
            logger.record_tabular("/".join(["test", k]), np.mean(recon_logging_dict[k]))
        logger.record_tabular("test/KL", np.mean(kles))
        logger.record_tabular("test/loss", np.mean(losses))
        logger.record_tabular("beta", beta)

        process = psutil.Process(os.getpid())
        logger.record_tabular("RAM Usage (Mb)", int(process.memory_info().rss / 1000000))

        num_active_dims = 0
        for std in self.model.dist_std:
            if std > 0.15:
                num_active_dims += 1
        logger.record_tabular("num_active_dims", num_active_dims)

        logger.dump_tabular()
        if save_vae:
            logger.save_itr_params(epoch, self.model, prefix='vae', save_anyway=True)  # slow...
Example #29
    def offline_evaluate(self, epoch):
        for key, value in self.eval_statistics.items():
            logger.record_tabular(key, value)
        self.need_to_update_eval_statistics = True

    def evaluate(self, epoch):
        statistics = OrderedDict()
        for key, value in statistics.items():
            logger.record_tabular(key, value)
        super().evaluate(epoch)
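Every example above follows the same accumulate-then-flush protocol: `record_tabular` buffers key/value pairs for the current row, and `dump_tabular` writes that row to the registered tabular outputs. The pattern distilled, with placeholder statistics:

from collections import OrderedDict

from railrl.core import logger

for epoch in range(3):
    stats = OrderedDict([('Example Loss', 0.0)])  # placeholder scalars
    logger.record_tabular('Epoch', epoch)
    for key, value in stats.items():
        logger.record_tabular(key, value)
    logger.dump_tabular()  # flush one row per epoch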