Example no. 1
    def evaluate(self):
        eval_statistics = OrderedDict()
        self.mlp.eval()
        self.encoder.eval()
        # for i in range(self.min_context_size, self.max_context_size):
        MAX_NUM = 9
        context, mask, input_batch, labels = self._get_batch()
        for i in range(1, MAX_NUM):
            # prep the batches
            # context, mask, input_batch, labels = self._get_training_batch()
            # context, mask, input_batch, labels = self._get_eval_batch(i)

            # mask selects only the first i context points for every task
            mask = Variable(
                ptu.from_numpy(
                    np.zeros((self.num_tasks_used_per_update,
                              self.max_context_size, 1))))
            mask[:, :i, :] = 1.0
            z = self.encoder(context, mask)

            repeated_z = z.repeat(
                1,
                self.classification_batch_size_per_task).view(-1, z.size(1))
            mlp_input = torch.cat([input_batch, repeated_z], dim=-1)
            preds = self.mlp(mlp_input)
            class_preds = (preds > 0).type(preds.data.type())
            accuracy = (class_preds == labels).type(torch.FloatTensor).mean()
            eval_statistics['Acc for %d' % i] = np.mean(
                ptu.get_numpy(accuracy))

        for key, value in eval_statistics.items():
            logger.record_tabular(key, value)
        logger.dump_tabular(with_prefix=False, with_timestamp=False)
        # print(np.mean(list(eval_statistics.values())))
        print('INPUT_DIM: %s' % INPUT_DIM)
Example no. 2
    def evaluate(self, epoch):
        statistics = OrderedDict()
        statistics.update(self.eval_statistics)
        self.eval_statistics = None

        logger.log("Collecting samples for evaluation")
        test_paths = self.eval_sampler.obtain_samples()

        statistics.update(eval_util.get_generic_path_information(
            test_paths, stat_prefix="Test",
        ))
        statistics.update(eval_util.get_generic_path_information(
            self._exploration_paths, stat_prefix="Exploration",
        ))
        if hasattr(self.env, "log_diagnostics"):
            self.env.log_diagnostics(test_paths)

        average_returns = rlkit.core.eval_util.get_average_returns(test_paths)
        statistics['AverageReturn'] = average_returns
        for key, value in statistics.items():
            logger.record_tabular(key, value)

        if self.render_eval_paths:
            self.env.render_paths(test_paths)

        if self.plotter:
            self.plotter.draw()
Example no. 3
    def log_diagnostics(self, paths):
        final_values = []
        final_unclipped_rewards = []
        final_rewards = []
        for path in paths:
            final_value = path["actions"][-1][0]
            final_values.append(final_value)
            score = path["observations"][0][0] * final_value
            final_unclipped_rewards.append(score)
            final_rewards.append(clip_magnitude(score, 1))

        last_statistics = OrderedDict()
        last_statistics.update(
            create_stats_ordered_dict(
                'Final Value',
                final_values,
            ))
        last_statistics.update(
            create_stats_ordered_dict(
                'Unclipped Final Rewards',
                final_unclipped_rewards,
            ))
        last_statistics.update(
            create_stats_ordered_dict(
                'Final Rewards',
                final_rewards,
            ))

        for key, value in last_statistics.items():
            logger.record_tabular(key, value)

        return final_unclipped_rewards
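Note on the path format: the snippet above assumes each element of paths is an rlkit-style rollout dict of per-step arrays. A tiny, self-contained illustration of the score it computes (the path contents below are made up for the example, not taken from any run):
import numpy as np

# hypothetical two-step path in the rollout format assumed above
path = {
    "observations": np.array([[0.5], [0.7]]),
    "actions": np.array([[0.2], [1.0]]),
}

# log_diagnostics scores a path as (first observation) * (last action)
final_value = path["actions"][-1][0]                 # 1.0
score = path["observations"][0][0] * final_value     # 0.5
print(final_value, score)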
Example no. 4
def simulate_policy(args):
    if args.pause:
        import ipdb
        ipdb.set_trace()
    data = pickle.load(open(args.file, "rb"))
    policy = data['policy']
    env = data['env']
    print("Policy and environment loaded")
    if args.gpu:
        ptu.set_gpu_mode(True)
        policy.to(ptu.device)
    if isinstance(env, VAEWrappedEnv):
        env.mode(args.mode)
    if args.enable_render or hasattr(env, 'enable_render'):
        # some environments need to be reconfigured for visualization
        env.enable_render()
    policy.train(False)
    paths = []
    while True:
        paths.append(
            multitask_rollout(
                env,
                policy,
                max_path_length=args.H,
                animated=not args.hide,
                observation_key='observation',
                desired_goal_key='desired_goal',
            ))
        if hasattr(env, "log_diagnostics"):
            env.log_diagnostics(paths)
        if hasattr(env, "get_diagnostics"):
            for k, v in env.get_diagnostics(paths).items():
                logger.record_tabular(k, v)
        logger.dump_tabular()
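simulate_policy above only reads a few attributes off args (file, H, mode, gpu, pause, hide, enable_render). A minimal argparse driver is sketched below; the flag defaults are placeholders, not taken from the original script:
import argparse

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument('file', type=str, help='path to the pickled snapshot')
    parser.add_argument('--H', type=int, default=300, help='max rollout length')
    parser.add_argument('--mode', type=str, default='video_env')
    parser.add_argument('--gpu', action='store_true')
    parser.add_argument('--pause', action='store_true')
    parser.add_argument('--hide', action='store_true')
    parser.add_argument('--enable_render', action='store_true')
    simulate_policy(parser.parse_args())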
Example no. 5
    def evaluate(self, epoch):
        """
        Perform evaluation for this algorithm.

        :param epoch: The epoch number.
        """
        statistics = OrderedDict()

        train_batch = self.get_batch()
        statistics.update(self._statistics_from_batch(train_batch, "Train"))

        logger.log("Collecting samples for evaluation")
        test_paths = self._sample_eval_paths()
        statistics.update(get_generic_path_information(
            test_paths, stat_prefix="Test",
        ))
        statistics.update(self._statistics_from_paths(test_paths, "Test"))
        average_returns = get_average_returns(test_paths)
        statistics['AverageReturn'] = average_returns

        statistics['Epoch'] = epoch

        for key, value in statistics.items():
            logger.record_tabular(key, value)

        self.env.log_diagnostics(test_paths)
        logger.dump_tabular(with_prefix=False, with_timestamp=False)
Example no. 6
    def evaluate(self, epoch, eval_paths=None):
        statistics = OrderedDict()
        statistics.update(self.eval_statistics)

        logger.log("Collecting samples for evaluation")
        if eval_paths:
            test_paths = eval_paths
        else:
            test_paths = self.get_eval_paths()
        statistics.update(
            eval_util.get_generic_path_information(
                test_paths,
                stat_prefix="Test",
            ))
        if len(self._exploration_paths) > 0:
            statistics.update(
                eval_util.get_generic_path_information(
                    self._exploration_paths,
                    stat_prefix="Exploration",
                ))
        if hasattr(self.env, "log_diagnostics"):
            self.env.log_diagnostics(test_paths, logger=logger)
        if hasattr(self.env, "get_diagnostics"):
            statistics.update(self.env.get_diagnostics(test_paths))

        average_returns = eval_util.get_average_returns(test_paths)
        statistics['AverageReturn'] = average_returns
        for key, value in statistics.items():
            logger.record_tabular(key, value)
        self.need_to_update_eval_statistics = True
Example no. 7
def simulate_policy(args):
    data = torch.load(args.file)
    policy = data['evaluation/policy']
    env = data['evaluation/env']
    print("Policy loaded")
    if args.gpu:
        set_gpu_mode(True)
        policy.cuda()
    paths = []
    while True:
        path = rollout(
            env,
            policy,
            max_path_length=args.H,
            render=True,
        )
        paths.append(path)
        if hasattr(env, "log_diagnostics"):
            env.log_diagnostics(paths)
        if hasattr(env, "get_diagnostics"):
            for k, v in env.get_diagnostics(paths).items():
                logger.record_tabular(k, v)
        else:
            logger.record_dict(
                eval_util.get_generic_path_information(paths),
                prefix="evaluation/",
            )
        logger.dump_tabular()
Example no. 8
    def log_diagnostics(self, paths, **kwargs):
        list_of_rewards, terminals, obs, actions, next_obs = split_paths(paths)

        returns = []
        for rewards in list_of_rewards:
            returns.append(np.sum(rewards))
        last_statistics = OrderedDict()
        last_statistics.update(
            create_stats_ordered_dict(
                'UndiscountedReturns',
                returns,
            ))
        last_statistics.update(
            create_stats_ordered_dict(
                'Rewards',
                list_of_rewards,
            ))
        last_statistics.update(create_stats_ordered_dict(
            'Actions',
            actions,
        ))

        for key, value in last_statistics.items():
            logger.record_tabular(key, value)
        return returns
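For readers unfamiliar with rlkit's eval utilities: create_stats_ordered_dict expands a name and an array into several scalar entries that record_tabular can then log. A rough approximation of the idea (a simplified stand-in, not the library function itself) is:
from collections import OrderedDict
import numpy as np

def stats_ordered_dict_sketch(name, data):
    # simplified stand-in for rlkit's create_stats_ordered_dict
    data = np.asarray(data)
    return OrderedDict([
        (name + ' Mean', np.mean(data)),
        (name + ' Std', np.std(data)),
        (name + ' Max', np.max(data)),
        (name + ' Min', np.min(data)),
    ])

print(stats_ordered_dict_sketch('UndiscountedReturns', [1.0, 2.0, 3.0]))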
Example no. 9
    def log_diagnostics(self, paths, **kwargs):
        list_of_rewards, terminals, obs, actions, next_obs = split_paths(paths)

        returns = []
        for rewards in list_of_rewards:
            returns.append(np.sum(rewards))
        statistics = OrderedDict()
        statistics.update(
            create_stats_ordered_dict(
                'Undiscounted Returns',
                returns,
            ))
        statistics.update(
            create_stats_ordered_dict(
                'Rewards',
                list_of_rewards,
            ))
        statistics.update(create_stats_ordered_dict(
            'Actions',
            actions,
        ))

        fraction_of_time_on_platform = [o[1] for o in obs]
        statistics['Fraction of time on platform'] = np.mean(
            fraction_of_time_on_platform)

        for key, value in statistics.items():
            logger.record_tabular(key, value)
        return returns
Example no. 10
def simulate_policy(args):
    data = torch.load(args.file)
    policy = data['evaluation/policy']
    env = data['evaluation/env']
    print("Policy and environment loaded")
    if args.gpu:
        ptu.set_gpu_mode(True)
        policy.to(ptu.device)
        print('Using GPU')
    if isinstance(env, VAEWrappedEnv) and hasattr(env, 'mode'):
        env.mode(args.mode)
        print('Set environment mode {}'.format(args.mode))
    if args.enable_render or hasattr(env, 'enable_render'):
        # some environments need to be reconfigured for visualization
        env.enable_render()
    paths = []
    while True:
        paths.append(
            multitask_rollout(
                env,
                policy,
                max_path_length=args.H,
                render=not args.hide,
                observation_key='observation',
                desired_goal_key='desired_goal',
            ))
        if hasattr(env, "log_diagnostics"):
            env.log_diagnostics(paths)
        if hasattr(env, "get_diagnostics"):
            for k, v in env.get_diagnostics(paths).items():
                logger.record_tabular(k, v)
        logger.dump_tabular()
Example no. 11
    def evaluate(self):
        eval_statistics = OrderedDict()
        self.mlp.eval()
        self.encoder.eval()
        # for i in range(self.min_context_size, self.max_context_size):
        for i in range(1, 9):
            # prep the batches
            # context, mask, input_batch, labels = self._get_training_batch()
            context, mask, input_batch, labels = self._get_eval_batch(i)
            post_dist = self.encoder(context, mask)

            # z = post_dist.sample() # N_tasks x Dim
            z = post_dist.mean

            repeated_z = z.repeat(1, self.classification_batch_size_per_task).view(-1, z.size(1))
            mlp_input = torch.cat([input_batch, repeated_z], dim=-1)
            preds = self.mlp(mlp_input)
            class_preds = (preds > 0).type(preds.data.type())
            accuracy = (class_preds == labels).type(torch.FloatTensor).mean()
            eval_statistics['Acc for %d' % i] = np.mean(ptu.get_numpy(accuracy))

        for key, value in eval_statistics.items():
            logger.record_tabular(key, value)
        logger.dump_tabular(with_prefix=False, with_timestamp=False)
        # print(np.mean(list(eval_statistics.values())))

        print('NUM_SHAPES: %d' % NUM_SHAPES)
        print('NUM_COLORS: %d' % NUM_COLORS)
        print('NUM_PER_IMAGE: %d' % NUM_PER_IMAGE)
        print('MODE: %s' % MODE)
Example no. 12
    def _log_stats(self, epoch):
        logger.log("Epoch {} finished".format(epoch), with_timestamp=True)
        """
        Replay Buffer
        """
        logger.record_dict(self.replay_buffer.get_diagnostics(),
                           prefix='replay_buffer/')
        """
        Trainer
        """
        logger.record_dict(self.trainer.get_diagnostics(), prefix='trainer/')
        """
        Exploration
        """
        logger.record_dict(self.expl_data_collector.get_diagnostics(),
                           prefix='exploration/')
        """
        Evaluation
        """
        logger.record_dict(
            self.eval_data_collector.get_diagnostics(),
            prefix='evaluation/',
        )
        eval_paths = self.eval_data_collector.get_epoch_paths()
        logger.record_dict(
            eval_util.get_generic_path_information(eval_paths),
            prefix="evaluation/",
        )
        """
        Misc
        """
        gt.stamp('logging')
        logger.record_dict(_get_epoch_timings())
        logger.record_tabular('Epoch', epoch)
        logger.dump_tabular(with_prefix=False, with_timestamp=False)
Example no. 13
    def test_epoch(
            self,
            epoch,
            save_reconstruction=True,
            save_vae=True,
            from_rl=False,
    ):
        self.model.eval()
        losses = []
        log_probs = []
        kles = []
        zs = []
        beta = float(self.beta_schedule.get_value(epoch))
        for batch_idx in range(10):
            next_obs = self.get_batch(train=False)
            reconstructions, obs_distribution_params, latent_distribution_params = self.model(next_obs)
            log_prob = self.model.logprob(next_obs, obs_distribution_params)
            kle = self.model.kl_divergence(latent_distribution_params)
            loss = -1 * log_prob + beta * kle

            encoder_mean = latent_distribution_params[0]
            z_data = ptu.get_numpy(encoder_mean.cpu())
            for i in range(len(z_data)):
                zs.append(z_data[i, :])
            losses.append(loss.item())
            log_probs.append(log_prob.item())
            kles.append(kle.item())

            if batch_idx == 0 and save_reconstruction:
                n = min(next_obs.size(0), 8)
                comparison = torch.cat([
                    next_obs[:n].narrow(start=0, length=self.imlength, dim=1)
                        .contiguous().view(
                        -1, self.input_channels, self.imsize, self.imsize
                    ).transpose(2, 3),
                    reconstructions.view(
                        self.batch_size,
                        self.input_channels,
                        self.imsize,
                        self.imsize,
                    )[:n].transpose(2, 3)
                ])
                save_dir = osp.join(logger.get_snapshot_dir(),
                                    'r%d.png' % epoch)
                save_image(comparison.data.cpu(), save_dir, nrow=n)

        zs = np.array(zs)

        self.eval_statistics['epoch'] = epoch
        self.eval_statistics['test/log prob'] = np.mean(log_probs)
        self.eval_statistics['test/KL'] = np.mean(kles)
        self.eval_statistics['test/loss'] = np.mean(losses)
        self.eval_statistics['beta'] = beta
        if not from_rl:
            for k, v in self.eval_statistics.items():
                logger.record_tabular(k, v)
            logger.dump_tabular()
            if save_vae:
                logger.save_itr_params(epoch, self.model)
Example no. 14
    def train_epoch(self,
                    epoch,
                    sample_batch=None,
                    batches=100,
                    from_rl=False):
        self.model.train()
        losses = []
        log_probs = []
        kles = []
        m_losses = []
        for batch_idx in range(batches):
            if sample_batch is not None:
                data = sample_batch(self.batch_size)
                next_obs = data['next_obs']
            else:
                next_obs = self.get_batch()
            self.optimizer.zero_grad()
            reconstructions, x_prob_losses, kle_losses, mask_losses, x_hats, masks = self.model(
                next_obs)

            # 1. m outside log
            # x_prob_loss = (sum(x_prob_losses)).sum() / next_obs.shape[0]
            # 2. m inside log
            x_prob_loss = -torch.log(
                sum(x_prob_losses)).sum() / next_obs.shape[0]

            kle_loss = self.beta * sum(kle_losses)
            mask_loss = self.gamma * mask_losses
            loss = x_prob_loss + kle_loss + mask_loss

            self.optimizer.zero_grad()
            loss.backward()
            losses.append(loss.item())
            log_probs.append(x_prob_loss.item())
            kles.append(kle_loss.item())
            m_losses.append(mask_loss.item())

            self.optimizer.step()
            if self.log_interval and batch_idx % self.log_interval == 0:
                print(x_prob_loss.item(), kle_loss.item(), mask_loss.item())
                # print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                #     epoch,
                #     batch_idx,
                #     len(self.train_loader.dataset),
                #     100. * batch_idx / len(self.train_loader),
                #     loss.item() / len(next_obs)))

        if from_rl:
            self.vae_logger_stats_for_rl['Train VAE Epoch'] = epoch
            self.vae_logger_stats_for_rl['Train VAE Log Prob'] = np.mean(
                log_probs)
            self.vae_logger_stats_for_rl['Train VAE KL'] = np.mean(kles)
            self.vae_logger_stats_for_rl['Train VAE Loss'] = np.mean(losses)
        else:
            logger.record_tabular("train/epoch", epoch)
            logger.record_tabular("train/Log Prob", np.mean(log_probs))
            logger.record_tabular("train/KL", np.mean(kles))
            logger.record_tabular("train/loss", np.mean(losses))
            logger.record_tabular('train/mask_loss', np.mean(m_losses))
Example no. 15
def train_vae(variant, return_data=False):
    from rlkit.misc.ml_util import PiecewiseLinearSchedule
    from rlkit.torch.vae.vae_trainer import ConvVAETrainer
    from rlkit.core import logger
    beta = variant["beta"]
    use_linear_dynamics = variant.get('use_linear_dynamics', False)
    generate_vae_dataset_fctn = variant.get('generate_vae_data_fctn',
                                            generate_vae_dataset)
    variant['generate_vae_dataset_kwargs'][
        'use_linear_dynamics'] = use_linear_dynamics
    train_dataset, test_dataset, info = generate_vae_dataset_fctn(
        variant['generate_vae_dataset_kwargs'])
    if use_linear_dynamics:
        action_dim = train_dataset.data['actions'].shape[2]
    else:
        action_dim = 0
    model = get_vae(variant, action_dim)

    logger.save_extra_data(info)
    logger.get_snapshot_dir()
    if 'beta_schedule_kwargs' in variant:
        beta_schedule = PiecewiseLinearSchedule(
            **variant['beta_schedule_kwargs'])
    else:
        beta_schedule = None

    vae_trainer_class = variant.get('vae_trainer_class', ConvVAETrainer)
    trainer = vae_trainer_class(model,
                                beta=beta,
                                beta_schedule=beta_schedule,
                                **variant['algo_kwargs'])
    save_period = variant['save_period']

    dump_skew_debug_plots = variant.get('dump_skew_debug_plots', False)
    for epoch in range(variant['num_epochs']):
        should_save_imgs = (epoch % save_period == 0)
        trainer.train_epoch(epoch, train_dataset)
        trainer.test_epoch(epoch, test_dataset)

        if should_save_imgs:
            trainer.dump_reconstructions(epoch)
            trainer.dump_samples(epoch)
            if dump_skew_debug_plots:
                trainer.dump_best_reconstruction(epoch)
                trainer.dump_worst_reconstruction(epoch)
                trainer.dump_sampling_histogram(epoch)

        stats = trainer.get_diagnostics()
        for k, v in stats.items():
            logger.record_tabular(k, v)
        logger.dump_tabular()
        trainer.end_epoch(epoch)

        if epoch % 50 == 0:
            logger.save_itr_params(epoch, model)
    logger.save_extra_data(model, 'vae.pkl', mode='pickle')
    if return_data:
        return model, train_dataset, test_dataset
    return model
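train_vae only touches a handful of keys in variant. A minimal configuration covering the required ones is sketched below; the nested values (N, test_p, lr) are illustrative placeholders, not taken from any experiment:
variant = dict(
    beta=20,                       # KL weight handed to the trainer
    num_epochs=100,
    save_period=25,                # how often reconstructions/samples are dumped
    generate_vae_dataset_kwargs=dict(
        N=10000,                   # hypothetical dataset-generation arguments
        test_p=0.9,
    ),
    algo_kwargs=dict(
        lr=1e-3,                   # hypothetical trainer kwargs
    ),
    # optional keys read via variant.get(...):
    # use_linear_dynamics=False,
    # dump_skew_debug_plots=False,
    # beta_schedule_kwargs=dict(x_values=[0, 100], y_values=[0, 20]),
)
# model = train_vae(variant)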
Example no. 16
def run_task(variant):
    from rlkit.core import logger
    print(variant)
    logger.log("Hello from script")
    logger.log("variant: " + str(variant))
    logger.record_tabular("value", 1)
    logger.dump_tabular()
    logger.log("snapshot_dir: " + logger.get_snapshot_dir())
Example no. 17
    def evaluate(self, epoch):
        statistics = OrderedDict()
        statistics.update(self.eval_statistics)
        self.eval_statistics = None

        # statistics.update(eval_util.get_generic_path_information(
        #     self._exploration_paths, stat_prefix="Exploration",
        # ))

        for mode in ['meta_train', 'meta_test']:
            logger.log("Collecting samples for evaluation")
            test_paths = self.obtain_eval_samples(epoch, mode=mode)

            statistics.update(
                eval_util.get_generic_path_information(
                    test_paths,
                    stat_prefix="Test " + mode,
                ))
            # print(statistics.keys())
            if hasattr(self.env, "log_diagnostics"):
                self.env.log_diagnostics(test_paths)
            if hasattr(self.env, "log_statistics"):
                log_stats = self.env.log_statistics(test_paths)
                new_log_stats = OrderedDict(
                    (k + ' ' + mode, v) for k, v in log_stats.items())
                statistics.update(new_log_stats)

            average_returns = rlkit.core.eval_util.get_average_returns(
                test_paths)
            statistics['AverageReturn ' + mode] = average_returns

            if self.render_eval_paths:
                self.env.render_paths(test_paths)

        # meta_test_this_epoch = statistics['Percent_Solved meta_test']
        meta_test_this_epoch = statistics['AverageReturn meta_test']
        if meta_test_this_epoch >= self.best_meta_test:
            # make sure you set save_algorithm to true then call save_extra_data
            prev_save_alg = self.save_algorithm
            self.save_algorithm = True
            if self.save_best:
                if epoch > self.save_best_after_epoch:
                    temp = self.replay_buffer
                    self.replay_buffer = None
                    logger.save_extra_data(self.get_extra_data_to_save(epoch),
                                           'best_meta_test.pkl')
                    self.replay_buffer = temp
                    self.best_meta_test = meta_test_this_epoch
                    print('\n\nSAVED ALG AT EPOCH %d\n\n' % epoch)
            self.save_algorithm = prev_save_alg

        for key, value in statistics.items():
            logger.record_tabular(key, value)

        if self.plotter:
            self.plotter.draw()
Example no. 18
    def _log_stats(self, epoch):
        logger.log("Epoch {} finished".format(epoch), with_timestamp=True)
        """
        Replay Buffer
        """
        logger.record_dict(self.replay_buffer.get_diagnostics(),
                           prefix='replay_buffer/')

        # If you want to save replay buffer as a whole, use this
        snap_shot_dir = logger.get_snapshot_dir()
        self.replay_buffer.save_buffer(snap_shot_dir + '/online_buffer.hdf5')
        """
        Trainer
        """
        logger.record_dict(self.trainer.get_diagnostics(), prefix='trainer/')
        """
        Exploration
        """
        logger.record_dict(self.expl_data_collector.get_diagnostics(),
                           prefix='exploration/')
        expl_paths = self.expl_data_collector.get_epoch_paths()
        # import ipdb; ipdb.set_trace()
        if hasattr(self.expl_env, 'get_diagnostics'):
            logger.record_dict(
                self.expl_env.get_diagnostics(expl_paths),
                prefix='exploration/',
            )
        if not self.batch_rl or self.eval_both:
            logger.record_dict(
                eval_util.get_generic_path_information(expl_paths),
                prefix="exploration/",
            )
        """
        Evaluation
        """
        logger.record_dict(
            self.eval_data_collector.get_diagnostics(),
            prefix='evaluation/',
        )
        eval_paths = self.eval_data_collector.get_epoch_paths()
        if hasattr(self.eval_env, 'get_diagnostics'):
            logger.record_dict(
                self.eval_env.get_diagnostics(eval_paths),
                prefix='evaluation/',
            )
        logger.record_dict(
            eval_util.get_generic_path_information(eval_paths),
            prefix="evaluation/",
        )
        """
        Misc
        """
        gt.stamp('logging')
        logger.record_dict(_get_epoch_timings())
        logger.record_tabular('Epoch', epoch)
        logger.dump_tabular(with_prefix=False, with_timestamp=False)
Example no. 19
    def log_diagnostics(self, paths):
        n_goal = len(self.goal_positions)
        goal_reached = [False] * n_goal

        for path in paths:
            last_obs = path["observations"][-1]
            for i, goal in enumerate(self.goal_positions):
                if np.linalg.norm(last_obs - goal) < self.goal_threshold:
                    goal_reached[i] = True

        logger.record_tabular('env:goals_reached', goal_reached.count(True))
Example no. 20
def simulate_policy(args):
    data = pickle.load(open(args.file, "rb"))
    policy_key = args.policy_type + '/policy'
    if policy_key in data:
        policy = data[policy_key]
    else:
        raise Exception("No policy found in loaded dict. Keys: {}".format(
            data.keys()))

    env_key = args.env_type + '/env'
    if env_key in data:
        env = data[env_key]
    else:
        raise Exception("No environment found in loaded dict. Keys: {}".format(
            data.keys()))

    if isinstance(env, RemoteRolloutEnv):
        env = env._wrapped_env
    print("Policy loaded")

    if args.enable_render:
        # some environments need to be reconfigured for visualization
        env.enable_render()
    if args.gpu:
        ptu.set_gpu_mode(True)
    if hasattr(policy, "to"):
        policy.to(ptu.device)
    if hasattr(env, "vae"):
        env.vae.to(ptu.device)

    if args.deterministic:
        policy = MakeDeterministic(policy)

    if args.pause:
        import ipdb
        ipdb.set_trace()
    if isinstance(policy, PyTorchModule):
        policy.train(False)
    paths = []
    while True:
        paths.append(
            rollout(
                env,
                policy,
                max_path_length=args.H,
                render=not args.hide,
            ))
        if args.log_diagnostics:
            if hasattr(env, "log_diagnostics"):
                env.log_diagnostics(paths, logger)
            for k, v in eval_util.get_generic_path_information(paths).items():
                logger.record_tabular(k, v)
            logger.dump_tabular()
Example no. 21
def simulate_policy(args):
    if args.pause:
        import ipdb
        ipdb.set_trace()
    data = pickle.load(open(args.file, "rb"))  # joblib.load(args.file)
    if 'policy' in data:
        policy = data['policy']
    elif 'evaluation/policy' in data:
        policy = data['evaluation/policy']
    else:
        raise KeyError(
            "No policy found in loaded dict. Keys: {}".format(data.keys()))

    if 'env' in data:
        env = data['env']
    elif 'evaluation/env' in data:
        env = data['evaluation/env']
    else:
        raise KeyError(
            "No environment found in loaded dict. Keys: {}".format(data.keys()))

    if isinstance(env, RemoteRolloutEnv):
        env = env._wrapped_env
    print("Policy loaded")
    if args.gpu:
        ptu.set_gpu_mode(True)
        policy.to(ptu.device)
    else:
        ptu.set_gpu_mode(False)
        policy.to(ptu.device)
    if isinstance(env, VAEWrappedEnv):
        env.mode(args.mode)
    if args.enable_render or hasattr(env, 'enable_render'):
        # some environments need to be reconfigured for visualization
        env.enable_render()
    if args.multitaskpause:
        env.pause_on_goal = True
    if isinstance(policy, PyTorchModule):
        policy.train(False)
    paths = []
    while True:
        paths.append(
            multitask_rollout(
                env,
                policy,
                max_path_length=args.H,
                render=not args.hide,
                observation_key=data.get('evaluation/observation_key',
                                         'observation'),
                desired_goal_key=data.get('evaluation/desired_goal_key',
                                          'desired_goal'),
            ))
        if hasattr(env, "log_diagnostics"):
            env.log_diagnostics(paths)
        if hasattr(env, "get_diagnostics"):
            for k, v in env.get_diagnostics(paths).items():
                logger.record_tabular(k, v)
        logger.dump_tabular()
Example no. 22
    def log_loss_under_uniform(self, model, data, priority_function_kwargs):
        import torch.nn.functional as F

        log_probs_prior = []
        log_probs_biased = []
        log_probs_importance = []
        kles = []
        mses = []
        for i in range(0, data.shape[0], self.batch_size):
            img = normalize_image(data[i : min(data.shape[0], i + self.batch_size), :])
            torch_img = ptu.from_numpy(img)
            reconstructions, obs_distribution_params, latent_distribution_params = self.model(
                torch_img
            )

            priority_function_kwargs["sampling_method"] = "true_prior_sampling"
            log_p, log_q, log_d = compute_log_p_log_q_log_d(
                model, img, **priority_function_kwargs
            )
            log_prob_prior = log_d.mean()

            priority_function_kwargs["sampling_method"] = "biased_sampling"
            log_p, log_q, log_d = compute_log_p_log_q_log_d(
                model, img, **priority_function_kwargs
            )
            log_prob_biased = log_d.mean()

            priority_function_kwargs["sampling_method"] = "importance_sampling"
            log_p, log_q, log_d = compute_log_p_log_q_log_d(
                model, img, **priority_function_kwargs
            )
            log_prob_importance = (log_p - log_q + log_d).mean()

            kle = model.kl_divergence(latent_distribution_params)
            mse = F.mse_loss(torch_img, reconstructions, reduction="elementwise_mean")
            mses.append(mse.item())
            kles.append(kle.item())
            log_probs_prior.append(log_prob_prior.item())
            log_probs_biased.append(log_prob_biased.item())
            log_probs_importance.append(log_prob_importance.item())

        logger.record_tabular(
            "Uniform Data Log Prob (True Prior)", np.mean(log_probs_prior)
        )
        logger.record_tabular(
            "Uniform Data Log Prob (Biased)", np.mean(log_probs_biased)
        )
        logger.record_tabular(
            "Uniform Data Log Prob (Importance)", np.mean(log_probs_importance)
        )
        logger.record_tabular("Uniform Data KL", np.mean(kles))
        logger.record_tabular("Uniform Data MSE", np.mean(mses))
Example no. 23
    def _try_to_eval(self, epoch=0):
        if epoch % self.save_extra_data_interval == 0:
            logger.save_extra_data(self.get_extra_data_to_save(epoch), epoch)
        if self._can_evaluate():
            self.evaluate(epoch)

            params = self.get_epoch_snapshot(epoch)
            logger.save_itr_params(epoch, params)
            table_keys = logger.get_table_key_set()
            if self._old_table_keys is not None:
                assert table_keys == self._old_table_keys, (
                    "Table keys cannot change from iteration to iteration.")
            self._old_table_keys = table_keys

            logger.record_tabular(
                "Number of train steps total",
                self._n_train_steps_total,
            )
            logger.record_tabular(
                "Number of env steps total",
                self._n_env_steps_total,
            )
            logger.record_tabular(
                "Number of rollouts total",
                self._n_rollouts_total,
            )

            logger.record_tabular("Epoch", epoch)
            logger.dump_tabular(with_prefix=False, with_timestamp=False)
        else:
            logger.log("Skipping eval for now.")
Example no. 24
    def evaluate(self, epoch):
        """
        Evaluate the policy, e.g. save/print progress.
        :param epoch:
        :return:
        """
        statistics = OrderedDict()
        try:
            statistics.update(self.eval_statistics)
            self.eval_statistics = None
        except (TypeError, AttributeError):
            # eval_statistics can be None (or unset) before the first update
            print('No Stats to Eval')

        logger.log("Collecting samples for evaluation")
        test_paths = self.eval_sampler.obtain_samples()

        statistics.update(
            eval_util.get_generic_path_information(
                test_paths,
                stat_prefix="Test",
            ))
        statistics.update(
            eval_util.get_generic_path_information(
                self._exploration_paths,
                stat_prefix="Exploration",
            ))

        if hasattr(self.env, "log_diagnostics"):
            self.env.log_diagnostics(test_paths)
        if hasattr(self.env, "log_statistics"):
            statistics.update(self.env.log_statistics(test_paths))
        if epoch % self.freq_log_visuals == 0:
            if hasattr(self.env, "log_visuals"):
                self.env.log_visuals(test_paths, epoch,
                                     logger.get_snapshot_dir())

        average_returns = eval_util.get_average_returns(test_paths)
        statistics['AverageReturn'] = average_returns
        for key, value in statistics.items():
            logger.record_tabular(key, value)

        best_statistic = statistics[self.best_key]
        if best_statistic > self.best_statistic_so_far:
            self.best_statistic_so_far = best_statistic
            if self.save_best and epoch >= self.save_best_starting_from_epoch:
                data_to_save = {'epoch': epoch, 'statistics': statistics}
                data_to_save.update(self.get_epoch_snapshot(epoch))
                logger.save_extra_data(data_to_save, 'best.pkl')
                print('\n\nSAVED BEST\n\n')
Example no. 25
    def evaluate(self, epoch, eval_paths=None):
        statistics = OrderedDict()
        if isinstance(self.epoch_discount_schedule, StatConditionalSchedule):
            table_dict = logger.get_table_dict()
            # rllab converts things to strings for some reason
            value = float(
                table_dict[self.epoch_discount_schedule.statistic_name])
            self.epoch_discount_schedule.update(value)

        if not isinstance(self.epoch_discount_schedule, ConstantSchedule):
            statistics['Discount Factor'] = self.discount

        for key, value in statistics.items():
            logger.record_tabular(key, value)
        super().evaluate(epoch, eval_paths=eval_paths)
Example no. 26
    def log_diagnostics(self, paths):
        Ntraj = len(paths)
        acts = np.array([traj['actions'] for traj in paths])
        obs = np.array([traj['observations'] for traj in paths])

        state_count = np.sum(obs, axis=1)
        states_visited = np.sum(state_count > 0, axis=-1)
        # log states visited
        logger.record_tabular('AvgStatesVisited', np.mean(states_visited))

        # log action block lengths
        traj_idx, _, acts_idx = np.where(acts == 1)
        acts_idx = np.array([acts_idx[traj_idx == i] for i in range(Ntraj)])

        if self.zero_reward:
            task_reward = np.array(
                [traj['env_infos']['task_reward'] for traj in paths])
            logger.record_tabular('ZeroedTaskReward',
                                  np.mean(np.sum(task_reward, axis=1)))
Example no. 27
    def log_diagnostics(self, paths):
        statistics = OrderedDict()

        for stat_name in [
            'arm to object distance',
            'object to goal distance',
            'arm to goal distance',
        ]:
            stat = get_stat_in_paths(
                paths, 'env_infos', stat_name
            )
            statistics.update(create_stats_ordered_dict(
                stat_name, stat
            ))

        for key, value in statistics.items():
            logger.record_tabular(key, value)
Example no. 28
    def train_epoch(self, epoch, batches=20, from_rl=False):
        self.model.train()
        losses = []
        log_probs = []
        kles = []
        mses = []

        for batch_idx in range(batches):
            next_obs, actions = self.get_batch()
            self.optimizer.zero_grad()

            # schedule for doing refinement or physics
            # refinement = 0, physics = 1
            # when only doing refinement predict same image
            # when only doing physics predict next image
            schedule = np.random.randint(0, 2, (self.model.T, ))
            schedule[:4] = 0
            #import pdb; pdb.set_trace()
            x_hat, mask, loss, kle_loss, x_prob_loss, mse, final_recon = self.model(
                next_obs, actions=actions, schedule=schedule)
            loss.backward()
            torch.nn.utils.clip_grad_norm_(
                [x for x in self.model.parameters()], 5.0)
            #torch.nn.utils.clip_grad_norm_(self.model.lambdas, 5.0)  # TODO Clip other gradients?
            self.optimizer.step()

            losses.append(loss.item())
            log_probs.append(x_prob_loss.item())
            kles.append(kle_loss.item())
            mses.append(mse.item())

            if self.log_interval and batch_idx % self.log_interval == 0:
                print(x_prob_loss.item(), kle_loss.item())

        if from_rl:
            self.vae_logger_stats_for_rl['Train VAE Epoch'] = epoch
            self.vae_logger_stats_for_rl['Train VAE Log Prob'] = np.mean(
                log_probs)
            self.vae_logger_stats_for_rl['Train VAE KL'] = np.mean(kles)
            self.vae_logger_stats_for_rl['Train VAE Loss'] = np.mean(losses)
        else:
            logger.record_tabular("train/epoch", epoch)
            logger.record_tabular("train/Log Prob", np.mean(log_probs))
            logger.record_tabular("train/KL", np.mean(kles))
            logger.record_tabular("train/loss", np.mean(losses))
            logger.record_tabular("train/mse", np.mean(mses))
Example no. 29
    def _update_beta(self):
        if self.replay_buffer._size > self.beta_batch_size:
            batch_beta = self.get_batch_custom(self.beta_batch_size)
            rewards_beta = batch_beta['rewards']
            terminals_beta = batch_beta['terminals']
            obs_beta = batch_beta['observations']
            actions_beta = batch_beta['actions']
            next_obs_beta = batch_beta['next_observations']
            with torch.no_grad():
                q_pred_beta = self.qf(obs_beta, actions_beta)
                v_pred_beta = self.vf(next_obs_beta)
                q_target_beta = rewards_beta + (
                    1. - terminals_beta) * self.discount * v_pred_beta
                # beta is set to the critic's Bellman error on this batch
                self.beta = self.qf_criterion(q_pred_beta, q_target_beta)
        else:
            self.beta = 1.0

        logger.record_tabular("beta", ptu.FloatTensor([self.beta])[0].item())
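Written out, the adaptive beta computed above is just the critic's Bellman error on the sampled batch, with qf_criterion (typically an MSE loss) as the error measure:
\[
\beta = \ell\bigl(Q(s, a),\; r + (1 - d)\,\gamma\,V(s')\bigr),
\qquad \ell = \text{qf\_criterion},\quad \gamma = \text{discount},\quad d = \text{terminal flag}
\]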
Example no. 30
    def evaluate(self):
        eval_statistics = OrderedDict()
        self.mlp.eval()
        self.encoder.eval()
        # for i in range(self.min_context_size, self.max_context_size+1):
        for i in range(1, 12):
            # prep the batches
            context_batch, mask, obs_task_params, classification_inputs, classification_labels = self._get_eval_batch(
                self.num_tasks_per_eval, i)
            # print(len(context_batch))
            # print(len(context_batch[0]))

            post_dist = self.encoder(context_batch, mask)
            z = post_dist.sample()  # N_tasks x Dim
            # z = post_dist.mean

            obs_task_params = Variable(ptu.from_numpy(obs_task_params))
            # print(obs_task_params)

            if self.training_regression:
                preds = self.mlp(z)
                loss = self.mse(preds, obs_task_params)
                eval_statistics['Loss for %d' % i] = np.mean(
                    ptu.get_numpy(loss))
            else:
                repeated_z = z.repeat(
                    1, self.classification_batch_size_per_task).view(
                        -1, z.size(1))
                mlp_input = torch.cat([classification_inputs, repeated_z],
                                      dim=-1)
                preds = self.mlp(mlp_input)
                # loss = self.bce(preds, classification_labels)
                class_preds = (preds > 0).type(preds.data.type())
                accuracy = (class_preds == classification_labels).type(
                    torch.FloatTensor).mean()
                eval_statistics['Acc for %d' % i] = np.mean(
                    ptu.get_numpy(accuracy))

        for key, value in eval_statistics.items():
            logger.record_tabular(key, value)
        logger.dump_tabular(with_prefix=False, with_timestamp=False)