def simulate_policy(args):
    """Load a saved policy/env snapshot from ``args.dir`` and dump a rollout video.

    Reads ``params.pkl`` and ``variant.json`` from ``args.dir``, trains the
    background model via ``train_bg``, rolls the policy out for ``args.H``
    steps with env-reset goals, and writes the frames to
    ``<args.dir>/visual-tmp.gif``.
    """
    data = torch.load(osp.join(args.dir, 'params.pkl'))
    policy = data['evaluation/policy']
    env = data['evaluation/env']
    # Fix: close the variant file deterministically — the original
    # `json.load(open(...))` leaked the file handle until GC.
    with open(osp.join(args.dir, 'variant.json')) as f:
        variant = json.load(f)
    variant = variant['variant']
    bg_variant = {
        'env_id': variant['env_id'],
        'imsize': variant['imsize'],
        'init_camera': cameras[variant['env_id']],
        'presampled_goals_path':
            variant['skewfit_variant'].get('presampled_goals_path'),
    }
    train_bg(bg_variant)
    imsize = variant['imsize']
    # Sample goals from env resets rather than decoding latents.
    env._goal_sampling_mode = 'reset_of_env'
    if args.gpu:
        ptu.set_gpu_mode(True)
        policy.stochastic_policy.to(ptu.device)
    print("Policy and environment loaded")
    # NOTE(review): reset is called twice in the original; kept as-is in
    # case the first reset is needed to settle env state — confirm.
    env.reset()
    env.reset()
    env.decode_goals = False
    from rlkit.util.video import dump_video
    save_dir = osp.join(args.dir, 'visual-tmp.gif')
    dump_video(
        env, policy, save_dir, rollout,
        horizon=args.H,
        imsize=imsize,
        rows=1, columns=8,
        fps=30,
    )
def save_video(algo, epoch):
    """Dump an evaluation video for this epoch when it is due.

    A video is saved on every ``save_period``-th epoch and on the final
    epoch of ``algo``; all other epochs are skipped.
    """
    # Guard clause: nothing to do unless this epoch is on the save
    # period or is the very last one.
    due = (epoch % save_period == 0) or (epoch == algo.num_epochs)
    if not due:
        return
    filename = osp.join(
        logdir, 'video_{epoch}_env.mp4'.format(epoch=epoch))
    dump_video(image_env, policy, filename, rollout_function,
               **dump_video_kwargs)
def _log_stats(self, epoch):
    """Log end-of-epoch diagnostics and periodically dump evaluation videos.

    Records replay-buffer, trainer, exploration, and evaluation diagnostics
    through ``logger``, then epoch timings and the epoch number, and flushes
    the tabular log. Every ``self.dump_video_interval`` epochs it also saves
    rollout videos and latent-distance plots for the eval env, once under the
    env's current goal-sampling mode and once under 'reset_of_env' goals.

    :param epoch: 1-based(?) epoch index used for log labels and filenames
        — TODO confirm indexing convention against the caller.
    """
    logger.log("Epoch {} finished".format(epoch), with_timestamp=True)

    """ dump video of policy """
    if epoch % self.dump_video_interval == 0:
        imsize = self.expl_env.imsize
        env = self.eval_env
        dump_path = logger.get_snapshot_dir()
        rollout_func = rollout

        # First video/plot: rollouts under the env's current goal-sampling
        # mode (whatever it is configured to at this point).
        video_name = "{}_sample_goal.gif".format(epoch)
        latent_distance_name = "{}_latent_sample_goal.png".format(epoch)
        dump_video(env, self.eval_data_collector._policy,
                   osp.join(dump_path, video_name),
                   rollout_function=rollout_func, imsize=imsize,
                   horizon=self.max_path_length, rows=1, columns=8)
        plot_latent_dist(env, self.eval_data_collector._policy,
                         save_name=osp.join(dump_path, latent_distance_name),
                         rollout_function=rollout_func,
                         horizon=self.max_path_length)

        # Second video/plot: temporarily switch to env-reset goals with
        # latent goal decoding disabled, then restore the saved settings
        # below so later evaluation is unaffected.
        old_goal_sampling_mode, old_decode_goals = env._goal_sampling_mode, env.decode_goals
        env._goal_sampling_mode = 'reset_of_env'
        env.decode_goals = False
        video_name = "{}_reset.gif".format(epoch)
        latent_distance_name = "{}_latent_reset.png".format(epoch)
        dump_video(env, self.eval_data_collector._policy,
                   osp.join(dump_path, video_name),
                   rollout_function=rollout_func, imsize=imsize,
                   horizon=self.max_path_length, rows=1, columns=8)
        plot_latent_dist(env, self.eval_data_collector._policy,
                         save_name=osp.join(dump_path, latent_distance_name),
                         rollout_function=rollout_func,
                         horizon=self.max_path_length)
        # `env` is an alias for self.eval_env, so this restores the
        # attributes mutated above.
        self.eval_env._goal_sampling_mode = old_goal_sampling_mode
        self.eval_env.decode_goals = old_decode_goals

    """ Replay Buffer """
    logger.record_dict(self.replay_buffer.get_diagnostics(),
                       prefix='replay_buffer/')

    """ Trainer """
    logger.record_dict(self.trainer.get_diagnostics(), prefix='trainer/')

    """ Exploration """
    logger.record_dict(self.expl_data_collector.get_diagnostics(),
                       prefix='exploration/')
    expl_paths = self.expl_data_collector.get_epoch_paths()
    # Env-specific diagnostics are optional — only logged when the env
    # exposes get_diagnostics.
    if hasattr(self.expl_env, 'get_diagnostics'):
        logger.record_dict(
            self.expl_env.get_diagnostics(expl_paths),
            prefix='exploration/',
        )
    logger.record_dict(
        eval_util.get_generic_path_information(expl_paths),
        prefix="exploration/",
    )

    """ Evaluation """
    logger.record_dict(
        self.eval_data_collector.get_diagnostics(),
        prefix='evaluation/',
    )
    eval_paths = self.eval_data_collector.get_epoch_paths()
    if hasattr(self.eval_env, 'get_diagnostics'):
        logger.record_dict(
            self.eval_env.get_diagnostics(eval_paths),
            prefix='evaluation/',
        )
    logger.record_dict(
        eval_util.get_generic_path_information(eval_paths),
        prefix="evaluation/",
    )

    """ Misc """
    gt.stamp('logging')
    logger.record_dict(_get_epoch_timings())
    logger.record_tabular('Epoch', epoch)
    # Flush the accumulated tabular row for this epoch.
    logger.dump_tabular(with_prefix=False, with_timestamp=False)